I have an API call:
<API_URL>
that returns a payload. Each API call corresponds to one record, which should be ingested into a table.
I have 200,000 records to ingest into my table, so I looped over them and ingested them one at a time, which took almost 5 hours. When I checked the logs, I found that all the file-system updates, snapshots, and log updates take time, and this overhead is repeated for every insert, i.e. 200,000 times, so it takes very long to process even a modest number of records.
So instead, I create an empty DataFrame and keep appending the output of each API call to it. That way I accumulate all the data in a single DataFrame and then simply write it to the table once. Here is how I implemented multithreading in Python:
import json
import threading
import traceback

from pyspark.sql import DataFrame, SparkSession


def prepare_empty_df(schema, spark: SparkSession) -> DataFrame:
    empty_rdd = spark.sparkContext.emptyRDD()
    empty_df = spark.createDataFrame(empty_rdd, schema)
    return empty_df


class RunApiCalls:
    def __init__(self, df: DataFrame = None):
        self.finalDf = df

    def do_some_transformations(self, df: DataFrame) -> DataFrame:
        # applies the schema to the JSON output (details elided)
        return do_some_transformation_output_dataframe

    def get_json(self, spark, PARAMETER):
        try:
            # create_bearer_token() and get_session() are helpers defined elsewhere
            token_headers = create_bearer_token()
            session = get_session()
            api_response = session.get(f'API_URL/?API_PARAMETER={PARAMETER}', headers=token_headers)
            print(f'API call: API_URL/?API_PARAMETER={PARAMETER} -> Status code: {api_response.status_code}')

            api_json_object = json.loads(api_response.text)
            string_data = json.dumps(api_json_object)
            json_df = spark.createDataFrame([(1, string_data)], ["id", "value"])

            api_dataframe = self.do_some_transformations(json_df)
            self.finalDf = self.finalDf.unionAll(api_dataframe)
        except Exception as error:
            traceback.print_exc()

    def api_main(self, spark, batch_size, state_names) -> DataFrame:
        try:
            for i in range(0, len(state_names), batch_size):
                sub_list = state_names[i:i + batch_size]
                threads = []
                for index in range(len(sub_list)):
                    t = threading.Thread(target=self.get_json, name=str(index),
                                         args=(spark, sub_list[index]))
                    threads.append(t)
                    t.start()
                for index, thread in enumerate(threads):
                    thread.join()
                print(f"All Threads completed for this sub_list{i}")
            return self.finalDf
        except Exception as e:
            traceback.print_exc()


if __name__ == "__main__":
    spark = SparkSession.builder.appName('SOME_APP_NAME').getOrCreate()
    batch_size = 15

    # schema and get_list() are defined elsewhere
    empty_df = prepare_empty_df(schema=schema, spark=spark)
    print('Created Empty Dataframe')

    api_param_list = get_list()
    print(f'api param list: {api_param_list}')

    api_call = RunApiCalls(df=empty_df)
    final_df = api_call.api_main(spark=spark, batch_size=batch_size, state_names=api_param_list)
    final_df.write.mode('append').saveAsTable("some_database.some_tablebname")
When I submit this code, I can see the multithreading running in the background along with its logs. Logs:
All Threads completed for this sub_list0
API call: API_URL/?API_PARAMETER={PARAMETER} -> Status code: 200
....
API call: API_URL/?API_PARAMETER={PARAMETER} -> Status code: 200
All Threads completed for this sub_list15
API call: API_URL/?API_PARAMETER={PARAMETER} -> Status code: 200
....
API call: API_URL/?API_PARAMETER={PARAMETER} -> Status code: 200
All Threads completed for this sub_list30
API call: API_URL/?API_PARAMETER={PARAMETER} -> Status code: 200
..
..
..
API call: API_URL/?API_PARAMETER={PARAMETER} -> Status code: 200
All Threads completed for this sub_list199985
In
do_some_transformations()
I don't do anything except apply the schema to the JSON output.
There are also no errors/failures when I dump the data into the table.
But when I check the data in the table, I don't see all the records:
select count(*) from some_database.some_tablebname
I only see 1735 records (result shown in the screenshot).
What's more, the count varies every time I run all the threads: sometimes it comes out to 5000, sometimes 8000, and so forth.
Every API call returns status code
200
and I also printed the content of an API response once to confirm that the calls really are returning data.
Could someone tell me what mistake I'm making here, so that I can see the full volume of data, i.e. one row per item in the list?
You are updating a shared resource from multiple threads, and those updates can interleave in a way that loses data. Consider:
self.finalDf = self.finalDf.unionAll(api_dataframe)
Suppose another thread runs this same statement while the right-hand side is still being evaluated in the first thread. When the first thread eventually finishes, it assigns its now-stale result back to
self.finalDf
, wiping out the other thread's union. Instead, you can guard this step with a lock to protect the DataFrame. In the example below I have added finalDfLock
.
import json
import threading
import traceback

from pyspark.sql import DataFrame, SparkSession


def prepare_empty_df(schema, spark: SparkSession) -> DataFrame:
    empty_rdd = spark.sparkContext.emptyRDD()
    empty_df = spark.createDataFrame(empty_rdd, schema)
    return empty_df


class RunApiCalls:
    def __init__(self, df: DataFrame = None):
        self.finalDf = df
        # lock that serializes every read-modify-write of self.finalDf
        self.finalDfLock = threading.RLock()

    def do_some_transformations(self, df: DataFrame) -> DataFrame:
        # applies the schema to the JSON output (details elided)
        return do_some_transformation_output_dataframe

    def get_json(self, spark, PARAMETER):
        try:
            # create_bearer_token() and get_session() are helpers defined elsewhere
            token_headers = create_bearer_token()
            session = get_session()
            api_response = session.get(f'API_URL/?API_PARAMETER={PARAMETER}', headers=token_headers)
            print(f'API call: API_URL/?API_PARAMETER={PARAMETER} -> Status code: {api_response.status_code}')

            api_json_object = json.loads(api_response.text)
            string_data = json.dumps(api_json_object)
            json_df = spark.createDataFrame([(1, string_data)], ["id", "value"])

            api_dataframe = self.do_some_transformations(json_df)
            # only one thread at a time may read, union, and reassign finalDf
            with self.finalDfLock:
                self.finalDf = self.finalDf.unionAll(api_dataframe)
        except Exception as error:
            traceback.print_exc()

    def api_main(self, spark, batch_size, state_names) -> DataFrame:
        try:
            for i in range(0, len(state_names), batch_size):
                sub_list = state_names[i:i + batch_size]
                threads = []
                for index in range(len(sub_list)):
                    t = threading.Thread(target=self.get_json, name=str(index),
                                         args=(spark, sub_list[index]))
                    threads.append(t)
                    t.start()
                for index, thread in enumerate(threads):
                    thread.join()
                print(f"All Threads completed for this sub_list{i}")
            return self.finalDf
        except Exception as e:
            traceback.print_exc()


if __name__ == "__main__":
    spark = SparkSession.builder.appName('SOME_APP_NAME').getOrCreate()
    batch_size = 15

    # schema and get_list() are defined elsewhere
    empty_df = prepare_empty_df(schema=schema, spark=spark)
    print('Created Empty Dataframe')

    api_param_list = get_list()
    print(f'api param list: {api_param_list}')

    api_call = RunApiCalls(df=empty_df)
    final_df = api_call.api_main(spark=spark, batch_size=batch_size, state_names=api_param_list)
    final_df.write.mode('append').saveAsTable("some_database.some_tablebname")
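The lost-update pattern can also be reproduced in isolation, without Spark at all. Here is a minimal standalone sketch (the names `counter` and `safe_increment` are illustrative, not from the code above): `counter += 1` is a read-modify-write just like `self.finalDf = self.finalDf.unionAll(...)`, and holding a lock across it guarantees no thread's update is discarded.

```python
import threading

counter = 0
counter_lock = threading.Lock()

def safe_increment(n: int) -> None:
    """Increment the shared counter n times, holding the lock for
    each read-modify-write so no update is ever lost."""
    global counter
    for _ in range(n):
        with counter_lock:
            counter += 1

threads = [threading.Thread(target=safe_increment, args=(100_000,))
           for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(counter)  # 400000: every one of the 4 x 100,000 updates survives
```

Remove the `with counter_lock:` line and the final count may come up short, and differently on each run, which is exactly the varying row counts (1735, 5000, 8000, ...) seen in the table.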