I'm using Airflow to organize my data pipeline. In one task, I'm trying to load a pickled object (a RouteModel instance) from S3:
    import pickle

    def read_file_from_s3(bucket, file):
        # Import RouteModel so the class is loaded before unpickling
        from inference.route_model import RouteModel
        s3_loader = S3Client(bucket, None)
        buffer = s3_loader.get_file(file)
        data = pickle.loads(buffer.read())
This gives me the following error:
    Traceback (most recent call last):
      File "/Users/cyrusghazanfar/Desktop/startup-studio/pilota_project/pilota_ml/env/lib/python3.6/site-packages/airflow/models/taskinstance.py", line 926, in _run_raw_task
        result = task_copy.execute(context=context)
      File "/Users/cyrusghazanfar/Desktop/startup-studio/pilota_project/pilota_ml/env/lib/python3.6/site-packages/airflow/operators/python_operator.py", line 113, in execute
        return_value = self.execute_callable()
      File "/Users/cyrusghazanfar/Desktop/startup-studio/pilota_project/pilota_ml/env/lib/python3.6/site-packages/airflow/operators/python_operator.py", line 118, in execute_callable
        return self.python_callable(*self.op_args, **self.op_kwargs)
      File "/Users/cyrusghazanfar/Desktop/startup-studio/pilota_project/pilota_ml/inference/predict.py", line 43, in get_pred_for_flight
        pred_state, pred_state_prob, pred_dt = tst_pipeline.get_prediction(format_pred_od)
      File "/Users/cyrusghazanfar/Desktop/startup-studio/pilota_project/pilota_ml/inference/pipeline.py", line 174, in get_prediction
        route_model = self.rm_loader.get_model(self.rm_dict[r_key]['rm_key'])
      File "/Users/cyrusghazanfar/Desktop/startup-studio/pilota_project/pilota_ml/inference/dataloader.py", line 40, in get_model
        route_model = read_file_from_s3(self.loc, fname)
      File "/Users/cyrusghazanfar/Desktop/startup-studio/pilota_project/pilota_ml/inference/dataloader.py", line 96, in read_file_from_s3
        data = pickle.loads(buffer.read())
    AttributeError: Can't get attribute 'RouteModel' on <module '__main__' from '/Users/cyrusghazanfar/Desktop/startup-studio/pilota_project/pilota_ml/env/bin/airflow'>
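As I understand it, when a pickle contains a custom class, that class must be reachable, under the module name recorded in the pickle, from the namespace of the process that reads it. Here the reading process is Airflow, so pickle looks for RouteModel on `__main__`, which is the `airflow` executable rather than my own code.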
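To illustrate the point about namespaces, here is a minimal sketch (not taken from the original code) of what a pickle records for a class: only a module name and a class name, which the reading process has to resolve itself. It uses `fractions.Fraction` from the standard library as a stand-in for a custom class; `RecordingUnpickler` and `lookups` are hypothetical names introduced only for this example.

```python
import io
import pickle
from fractions import Fraction

# The payload does not contain the class itself, only a reference to it.
payload = pickle.dumps(Fraction(1, 2))

lookups = []


class RecordingUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        # pickle calls this hook for every class or function reference it
        # meets in the payload; record the request, then resolve it normally.
        lookups.append((module, name))
        return super().find_class(module, name)


value = RecordingUnpickler(io.BytesIO(payload)).load()
print(lookups)
print(value)
```

In the failing case above, the error message suggests the recorded pair is `('__main__', 'RouteModel')`, so the default lookup lands on the `airflow` executable instead of `inference.route_model`.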
Note: I can't change the way the file is pickled.
Please help :)
To solve this, I had to write my own custom unpickler that explicitly returns the custom class whenever the pickle file refers to it:
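    import io
    import pickle

    class CustomUnpickler(pickle.Unpickler):
        def find_class(self, module, name):
            # Resolve RouteModel from its real module instead of the recorded one
            if name == 'RouteModel':
                from inference.route_model import RouteModel
                return RouteModel
            return super().find_class(module, name)

    # buffer is the file object returned by s3_loader.get_file(file)
    data = CustomUnpickler(io.BytesIO(buffer.read())).load()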
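For anyone who wants to try the idea without Airflow or S3, below is a self-contained sketch that reproduces the same kind of failure and then applies the same `find_class` override. Everything in it is an assumption made for illustration: the module name `training_script`, the `route` attribute, and the `RedirectingUnpickler` name are hypothetical, and the `RouteModel` here is a toy class, not the real one.

```python
import io
import pickle
import sys
import types


class RouteModel:
    # Pretend the class was defined in a hypothetical script module, much as
    # the original pickle recorded RouteModel under "__main__".
    __module__ = "training_script"

    def __init__(self, route):
        self.route = route


# Writing side: the module named in the pickle really exposes the class.
writer_module = types.ModuleType("training_script")
writer_module.RouteModel = RouteModel
sys.modules["training_script"] = writer_module
payload = pickle.dumps(RouteModel("SFO-JFK"))

# Reading side: the module exists but no longer has the attribute, which is
# the position Airflow's own "__main__" is in.
del writer_module.RouteModel

try:
    pickle.loads(payload)
    plain_load_error = None
except AttributeError as exc:
    plain_load_error = exc


class RedirectingUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        # Ignore the recorded module for this one class and return it from
        # wherever the reading process can actually reach it.
        if name == "RouteModel":
            return RouteModel
        return super().find_class(module, name)


restored = RedirectingUnpickler(io.BytesIO(payload)).load()
print(type(plain_load_error).__name__)
print(type(restored).__name__, restored.route)
```

Matching on `name` alone mirrors the code above; if the pickle could contain another class that is also called `RouteModel`, checking `module` as well would be the safer choice.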