I wrote a custom class to represent an ensemble model, and I want to pickle it for later use. Here is how I build the model and save the object with pickle:
import pickle
from typing import Any, List

class EnsembleModel:
    def __init__(self, estimators: List[Any]):
        ...

    def fit(self, ...):
        return ...

    def predict(self, ...):
        return ...

ensemble_model = EnsembleModel(estimators=[est_1, est_2, est_3], ...)
ensemble_model.fit(X_train, y_train)

with open("ensemble-model.mdl", "wb") as f:
    pickle.dump(ensemble_model, f)
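Now I want to use this binary object, "ensemble-model.mdl", in other code. I have a more general class, MyModel, that loads and represents models of this kind (why this class exists is beside the point here). As you can see, this class is responsible for unpickling the EnsembleModel object: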
class MyModel:
    model: Any = None
    _probability: float = None
    _predict_method: str = None
    _predict_proba: Any = None

    def __init__(self, model_path: str, model_name: str = 'MyModel', threshold: float = 0.5) -> None:
        with open(model_path, 'rb') as f:
            self.model = pickle.load(f)
        ...
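I keep both classes, EnsembleModel and MyModel, in a single script (my_package/model.py) inside a separate package named my_package, and I install this package in the virtual environment where my main script runs.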
my_package:
  - __init__.py
  - model.py
      -> EnsembleModel
      -> MyModel
  - ...
This is my main script, where I need to use the binary object "ensemble-model.mdl" to initialize an instance of MyModel:
main.py
import asyncio
from my_package.model import EnsembleModel, MyModel

async def main():
    model = MyModel(
        model_name="My Ensemble Model",
        model_path="ensemble-model.mdl",
        threshold=0.5
    )

loop = asyncio.get_event_loop()
loop.run_until_complete(main())
loop.close()
Running main.py produces this error:
Traceback (most recent call last):
  File "C:\Program Files\JetBrains\PyCharm Community Edition 2022.1.3\plugins\python-ce\helpers\pydev\pydevconsole.py", line 364, in runcode
    coro = func()
  File "<input>", line 2, in <module>
  File "C:\Users\mehdi\AppData\Local\Programs\Python\Python38\lib\asyncio\base_events.py", line 616, in run_until_complete
    return future.result()
  File "D:\my_app\my_pipeline\__init__.py", line 108, in main
    MyModel(
  File "D:\my_app\venv\lib\site-packages\my_package\model.py", line 92, in __init__
    self.model = pickle.load(f)
AttributeError: Can't get attribute 'EnsembleModel' on <module '__main__' from '<input>'>
The error appears to come from the pickling/unpickling process, but I don't know how to fix it. Any ideas on how to fix this error?
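For context, here is a minimal sketch of why this kind of error can arise. pickle does not store the class definition itself; it records only the name of the module the class was defined in, plus the class name, and looks both up again at load time. The class body and estimator values below are hypothetical stand-ins, not the real EnsembleModel.

```python
import pickle


class EnsembleModel:
    """Hypothetical stand-in for the real ensemble class."""

    def __init__(self, estimators):
        self.estimators = estimators


payload = pickle.dumps(EnsembleModel(estimators=["est_1", "est_2"]))

# The module the class was defined in when dumps() ran; for a script that is
# run directly, this is "__main__".
recorded_module = EnsembleModel.__module__

# The module name and class name are written into the byte stream ...
module_in_payload = recorded_module.encode() in payload
class_name_in_payload = b"EnsembleModel" in payload
# ... but the class body (for example its docstring) is not.
docstring_in_payload = b"Hypothetical stand-in" in payload

print(recorded_module, module_in_payload, class_name_in_payload, docstring_in_payload)
```

If the loading process cannot find a class with that name in a module with that name, pickle.load raises an error like the one in the traceback.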
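This is a late answer, but it is worth mentioning. I was able to solve the problem by registering my_package.model under the name 'model' in sys.modules, since 'model' is the name under which the module was imported in the training environment. pickle looks a class up by the module name recorded when the object was dumped, so the alias lets it find EnsembleModel again.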
import asyncio
import sys

import my_package
from my_package.model import EnsembleModel, MyModel

# Alias the module name recorded in the pickle to the installed package module.
sys.modules['model'] = my_package.model

async def main():
    model = MyModel(
        model_name="My Ensemble Model",
        model_path="ensemble-model.mdl",
        threshold=0.5
    )

loop = asyncio.get_event_loop()
loop.run_until_complete(main())
loop.close()
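An alternative that avoids touching sys.modules globally is to subclass pickle.Unpickler and override find_class, which pickle calls with the recorded module and class name. The sketch below is only an illustration under stated assumptions: EnsembleModel is a stand-in class, the REMAP table and the dump_as_if_from helper are hypothetical names, and the helper exists only to simulate a file that was pickled while the class lived in a top-level module called 'model'.

```python
import io
import pickle
import sys
import types


class EnsembleModel:
    """Stand-in; in the question the real class lives in my_package.model."""

    def __init__(self, estimators):
        self.estimators = estimators


def dump_as_if_from(module_name, obj):
    """Demo helper (hypothetical): pickle obj as though its class had been
    defined in `module_name`, to imitate the training environment."""
    cls = type(obj)
    original_module = cls.__module__
    previous = sys.modules.get(module_name)
    fake = types.ModuleType(module_name)
    setattr(fake, cls.__name__, cls)
    sys.modules[module_name] = fake
    cls.__module__ = module_name
    try:
        return pickle.dumps(obj)
    finally:
        # Undo the simulation so the stale module name no longer resolves.
        cls.__module__ = original_module
        if previous is None:
            del sys.modules[module_name]
        else:
            sys.modules[module_name] = previous


class RemappingUnpickler(pickle.Unpickler):
    """Redirect stale (module, name) references to the class's current home."""

    # Assumed mapping for this demo; extend it with whatever names were recorded.
    REMAP = {
        ("model", "EnsembleModel"): EnsembleModel,
        ("__main__", "EnsembleModel"): EnsembleModel,
    }

    def find_class(self, module, name):
        target = self.REMAP.get((module, name))
        if target is not None:
            return target
        # Everything else is resolved the normal way.
        return super().find_class(module, name)


payload = dump_as_if_from("model", EnsembleModel(estimators=["a", "b"]))
restored = RemappingUnpickler(io.BytesIO(payload)).load()
print(type(restored).__name__, restored.estimators)
```

In MyModel.__init__, the equivalent change would be to call RemappingUnpickler(f).load() instead of pickle.load(f). For new files, pickling from a script that imports EnsembleModel from my_package.model should record the package path in the first place, so no remapping would be needed.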