How to fix data loss in multithreaded API calls and append the data to a Spark DataFrame?

Problem description

I have an API call:

<API_URL>

It returns a payload. Each API call corresponds to 1 record, which should be ingested into a table.

My table needs 200,000 records ingested, so I ran them in a loop, ingesting one record at a time, and it took almost 5 hours. I checked the logs and saw that all the file system updates, snapshots, and log updates take time, and that this work is repeated for every single insert, i.e. 200,000 times, which is why even this modest number of records takes so long to process.

So I created an empty DataFrame and kept appending the output of each API call to it, so that I end up with a single dataframe that accumulates all the data and can simply be written to the table once at the end. This is how I implemented it with multithreading in Python:

import json
import threading
import traceback

from pyspark.sql import SparkSession, DataFrame

def prepare_empty_df(schema, spark: SparkSession) -> DataFrame:
    # build an empty DataFrame with the target schema so API results can be unioned onto it
    empty_rdd = spark.sparkContext.emptyRDD()
    empty_df = spark.createDataFrame(empty_rdd, schema)
    return empty_df

class RunApiCalls:
    def __init__(self, df: DataFrame=None):
        self.finalDf = df


    def do_some_transformations(self, df: DataFrame) -> DataFrame:
        # placeholder: only applies the schema to the JSON output (see description below)
        return do_some_transformation_output_dataframe


    def get_json(self, spark, PARAMETER):
        try:
            token_headers = create_bearer_token()
            session = get_session()
            api_response = session.get(f'API_URL/?API_PARAMETER={PARAMETER}', headers=token_headers)
            print(f'API call: API_URL/?API_PARAMETER={PARAMETER} -> Status code: {api_response.status_code}')
            api_json_object = json.loads(api_response.text)
            string_data = json.dumps(api_json_object)
            json_df = spark.createDataFrame([(1, string_data)],["id","value"])
            api_dataframe = self.do_some_transformations(json_df)
            self.finalDf = self.finalDf.unionAll(api_dataframe)
        except Exception:
            traceback.print_exc()

    def api_main(self, spark, batch_size, state_names) -> DataFrame:
        try:
            for i in range(0, len(state_names), batch_size):
                sub_list = state_names[i:i + batch_size]
                threads = []
                for index in range(len(sub_list)):
                    t = threading.Thread(target=self.get_json, name=str(index), args=(spark, sub_list[index]))
                    threads.append(t)
                    t.start()
                for index, thread in enumerate(threads):
                    thread.join()
                print(f"All Threads completed for this sub_list{i}")
            return self.finalDf
        except Exception as e:
            traceback.print_exc()



if __name__ == "__main__": 
    spark = SparkSession.builder.appName('SOME_APP_NAME').getOrCreate()
    batch_size = 15
    empty_df = prepare_empty_df(schema=schema, spark=spark)
    print('Created Empty Dataframe')
    api_param_list = get_list()
    print(f'api param list: {api_param_list}')
    api_call = RunApiCalls(df=empty_df)
    final_df = api_call.api_main(spark=spark, batch_size=batch_size, state_names=api_param_list)
    final_df.write.mode('append').saveAsTable("some_database.some_tablebname")

When I submit this code, I can see the multithreading running in the background along with its logs. Logs:

All Threads completed for this sub_list0
API call: API_URL/?API_PARAMETER={PARAMETER} -> Status code: 200
....
API call: API_URL/?API_PARAMETER={PARAMETER} -> Status code: 200
All Threads completed for this sub_list15
API call: API_URL/?API_PARAMETER={PARAMETER} -> Status code: 200
....
API call: API_URL/?API_PARAMETER={PARAMETER} -> Status code: 200
All Threads completed for this sub_list30
API call: API_URL/?API_PARAMETER={PARAMETER} -> Status code: 200
..
..
..
API call: API_URL/?API_PARAMETER={PARAMETER} -> Status code: 200
All Threads completed for this sub_list199985

In do_some_transformations() I don't do anything except apply the schema to the JSON output. There are also no errors or failures when I dump the data into the table. But when I check the data in the table, I don't see all the records.

select count(*) from some_database.some_tablebname

I only see 1735 records (result shown in the screenshot), and this count is different every time I run all the threads: sometimes it comes out as 5000, sometimes as 8000, and so on.

Every API call returns status code 200, and I also printed the content of an API response once and confirmed that the calls really are returning data. Can someone point out what I'm doing wrong here, so that I end up with the full volume of data, i.e. as many rows as there are entries in the list?

python apache-spark pyspark python-multiprocessing python-multithreading
1 Answer

You are updating a shared resource from multiple threads, and those updates can be interrupted, losing data. Consider

self.finalDf = self.finalDf.unionAll(api_dataframe)

Suppose that while the right-hand side is being evaluated in this thread, another thread performs the same assignment. When this thread finally finishes, it assigns a result built from old, stale data back to self.finalDf.
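The same read-modify-write race can be reproduced without Spark. Below is a minimal illustrative sketch; the Accumulator class and add_item method are hypothetical and only stand in for self.finalDf and get_json:

import threading
import time

class Accumulator:
    def __init__(self):
        self.shared = []  # plays the role of self.finalDf

    def add_item(self, item):
        current = self.shared            # read the shared value
        time.sleep(0.001)                # pretend the right-hand side takes a while
        self.shared = current + [item]   # write back, overwriting any concurrent update

acc = Accumulator()
threads = [threading.Thread(target=acc.add_item, args=(i,)) for i in range(100)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(len(acc.shared))  # typically far fewer than 100: concurrent updates clobber each other

self.finalDf = self.finalDf.unionAll(api_dataframe) has the same read-then-write shape, so unions performed by other threads can be silently discarded in exactly this way.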
Instead, you can protect the dataframe with a lock around this step. In the example below I have added finalDfLock:

import json
import threading
import traceback

from pyspark.sql import SparkSession, DataFrame

def prepare_empty_df(schema, spark: SparkSession) -> DataFrame:
    empty_rdd = spark.sparkContext.emptyRDD()
    empty_df = spark.createDataFrame(empty_rdd, schema)
    return empty_df

class RunApiCalls:
    def __init__(self, df: DataFrame=None):
        self.finalDf = df
        self.finalDfLock = threading.RLock()  # guards every read-modify-write of self.finalDf

    def do_some_transformations(self, df: DataFrame) -> DataFrame:
        return do_some_transformation_output_dataframe


    def get_json(self, spark, PARAMETER):
        try:
            token_headers = create_bearer_token()
            session = get_session()
            api_response = session.get(f'API_URL/?API_PARAMETER={PARAMETER}', headers=token_headers)
            print(f'API call: API_URL/?API_PARAMETER={PARAMETER} -> Status code: {api_response.status_code}')
            api_json_object = json.loads(api_response.text)
            string_data = json.dumps(api_json_object)
            json_df = spark.createDataFrame([(1, string_data)],["id","value"])
            api_dataframe = self.do_some_transformations(json_df)
            # hold the lock across the read-modify-write so another thread's union is not lost
            with self.finalDfLock:
                self.finalDf = self.finalDf.unionAll(api_dataframe)
        except Exception:
            traceback.print_exc()

    def api_main(self, spark, batch_size, state_names) -> DataFrame:
        try:
            for i in range(0, len(state_names), batch_size):
                sub_list = state_names[i:i + batch_size]
                threads = []
                for index in range(len(sub_list)):
                    t = threading.Thread(target=self.get_json, name=str(index), args=(spark, sub_list[index]))
                    threads.append(t)
                    t.start()
                for index, thread in enumerate(threads):
                    thread.join()
                print(f"All Threads completed for this sub_list{i}")
            return self.finalDf
        except Exception as e:
            traceback.print_exc()



if __name__ == "__main__": 
    spark = SparkSession.builder.appName('SOME_APP_NAME').getOrCreate()
    batch_size = 15
    empty_df = prepare_empty_df(schema=schema, spark=spark)
    print('Created Empty Dataframe')
    api_param_list = get_list()
    print(f'api param list: {api_param_list}')
    api_call = RunApiCalls(df=empty_df)
    final_df = api_call.api_main(spark=spark, batch_size=batch_size, state_names=api_param_list)
    final_df.write.mode('append').saveAsTable("some_database.some_tablebname")
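As a possible follow-up on the design (a hedged sketch, not part of the answer above): performing one unionAll per API call can leave the final DataFrame with a very long chain of unions in its plan. An alternative is to let each thread append its small per-call DataFrame to a plain Python list under the lock and combine everything once at the end. The subclass RunApiCallsCollecting and the method collect_results below are hypothetical names, and the helpers (create_bearer_token, get_session, do_some_transformations) are assumed to be the same as above:

from functools import reduce

class RunApiCallsCollecting(RunApiCalls):
    def __init__(self, df: DataFrame = None):
        super().__init__(df)
        self.results = []  # one small DataFrame per API call

    def get_json(self, spark, PARAMETER):
        try:
            token_headers = create_bearer_token()
            session = get_session()
            api_response = session.get(f'API_URL/?API_PARAMETER={PARAMETER}', headers=token_headers)
            api_json_object = json.loads(api_response.text)
            string_data = json.dumps(api_json_object)
            json_df = spark.createDataFrame([(1, string_data)], ["id", "value"])
            api_dataframe = self.do_some_transformations(json_df)
            # list.append is already atomic in CPython, but the lock keeps the intent explicit
            with self.finalDfLock:
                self.results.append(api_dataframe)
        except Exception:
            traceback.print_exc()

    def collect_results(self) -> DataFrame:
        # one pass of unions at the end instead of one union per API call
        return reduce(DataFrame.unionAll, self.results, self.finalDf)

With this variant the caller would run api_main() as before, but then write the DataFrame returned by collect_results() instead of the one api_main() returns.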