Pickle AttributeError: Can't get attribute on <module '__main__' from '<input>'>


I wrote a custom class to represent an ensemble model, and I want to pickle it for later use. Here is how I build the model and save the object with pickle:

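import pickle
from typing import Any, List


class EnsembleModel:
    def __init__(self, estimators: List[Any]):
        self.estimators = estimators
        ...  # rest of the constructor omitted

    def fit(self, X, y):
        ...  # training logic omitted

    def predict(self, X):
        ...  # prediction logic omitted


# est_1, est_2, est_3, X_train and y_train are defined elsewhere (omitted)
ensemble_model = EnsembleModel(estimators=[est_1, est_2, est_3])  # other arguments: ...
ensemble_model.fit(X_train, y_train)

with open("ensemble-model.mdl", "wb") as f:
    pickle.dump(ensemble_model, f)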
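Note that pickle.dump does not serialize the class itself. For an instance of a user-defined class it stores a reference made of the class's module name (EnsembleModel.__module__) and its qualified name, plus the instance's attributes. The sketch below makes that visible. training_script is a hypothetical module name standing in for whatever module defined the class when it was pickled (for a script run directly, or for an interactive console, that module is __main__).

```python
import pickle
import sys
import types

# Hypothetical module standing in for the script that defined the class at training time
training_module = types.ModuleType("training_script")
sys.modules["training_script"] = training_module


class EnsembleModel:
    def __init__(self, estimators):
        self.estimators = estimators


# Make the class look as if it had been defined inside "training_script"
EnsembleModel.__module__ = "training_script"
training_module.EnsembleModel = EnsembleModel

payload = pickle.dumps(EnsembleModel(estimators=["est_1", "est_2"]))

# The bytes contain only a reference (module name + class name), not the class code
print(b"training_script" in payload)  # True
print(b"EnsembleModel" in payload)    # True
print(EnsembleModel.__module__, EnsembleModel.__qualname__)  # training_script EnsembleModel
```

Whatever module name ends up in the file is the name the loading process must be able to import later, and that module must expose a class with the same name.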

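Now I want to use this binary object, "ensemble-model.mdl", in another piece of code. I have a more general class, say MyModel, that loads and represents such models (why this class exists is not the point here). As you can see, this class is responsible for unpickling the EnsembleModel object: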

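import pickle
from typing import Any, Optional


class MyModel:
    model: Any = None
    _probability: Optional[float] = None
    _predict_method: Optional[str] = None
    _predict_proba: Any = None

    def __init__(self, model_path: str, model_name: str = 'MyModel', threshold: float = 0.5) -> None:
        # Unpickle the saved model object (here, the EnsembleModel instance)
        with open(model_path, 'rb') as f:
            self.model = pickle.load(f)

        ...  # remaining initialization omitted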

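I keep both classes, EnsembleModel and MyModel, in a single script (my_package/model.py) inside a separate package named my_package, and I have installed this package in the virtual environment in which my main script runs.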

my_package:
- __init__.py
- model.py
    -> EnsembleModel
    -> MyModel
- ...
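Because MyModel is the only place that unpickles, it could also absorb a module-name mismatch itself. An alternative to the sys.modules patch used in the answer is to subclass pickle.Unpickler and override find_class, the method the unpickler calls to resolve each recorded module/name pair. This is a minimal sketch: RenamingUnpickler and its module_map argument are hypothetical, and old_training_script and new_home are stand-in module names. For the error in this question, the mapping might look like {"__main__": "my_package.model"}, assuming the class is importable from there.

```python
import io
import pickle
import sys
import types
from typing import Dict


class RenamingUnpickler(pickle.Unpickler):
    """Unpickler that rewrites recorded module names before looking classes up."""

    def __init__(self, file, module_map: Dict[str, str]):
        super().__init__(file)
        # Hypothetical mapping: module name recorded in the pickle -> module that exists now
        self.module_map = module_map

    def find_class(self, module, name):
        # Redirect the lookup, then use the normal resolution logic
        return super().find_class(self.module_map.get(module, module), name)


# --- Demo with hypothetical module names ---
class EnsembleModel:
    def __init__(self, estimators):
        self.estimators = estimators


# "Training": the class is recorded as old_training_script.EnsembleModel
old_module = types.ModuleType("old_training_script")
sys.modules["old_training_script"] = old_module
EnsembleModel.__module__ = "old_training_script"
old_module.EnsembleModel = EnsembleModel
payload = pickle.dumps(EnsembleModel(["est_1", "est_2"]))

# "Loading": the old module is gone and the class now lives in new_home
del sys.modules["old_training_script"]
new_module = types.ModuleType("new_home")
sys.modules["new_home"] = new_module
EnsembleModel.__module__ = "new_home"
new_module.EnsembleModel = EnsembleModel

restored = RenamingUnpickler(
    io.BytesIO(payload), module_map={"old_training_script": "new_home"}
).load()
print(type(restored).__name__, restored.estimators)  # EnsembleModel ['est_1', 'est_2']
```

Inside MyModel.__init__, the pickle.load(f) call would then become RenamingUnpickler(f, module_map=...).load(). Unlike the sys.modules approach, this keeps the renaming local to one load instead of changing interpreter-wide state.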

Here is my main script, main.py, where I need to initialize an instance of MyModel from the binary object "ensemble-model.mdl":

import asyncio

from my_package.model import EnsembleModel, MyModel


async def main():
    model = MyModel(
        model_name="My Ensemble Model",
        model_path="ensemble-model.mdl",
        threshold=0.5
    )


loop = asyncio.get_event_loop()
loop.run_until_complete(main())
loop.close()

Running main.py gives this error:

Traceback (most recent call last):
  File "C:\Program Files\JetBrains\PyCharm Community Edition 2022.1.3\plugins\python-ce\helpers\pydev\pydevconsole.py", line 364, in runcode
    coro = func()
  File "<input>", line 2, in <module>
  File "C:\Users\mehdi\AppData\Local\Programs\Python\Python38\lib\asyncio\base_events.py", line 616, in run_until_complete
    return future.result()
  File "D:\my_app\my_pipeline\__init__.py", line 108, in main
    MyModel(
  File "D:\my_app\venv\lib\site-packages\my_package\model.py", line 92, in __init__
    self.model = pickle.load(f)
AttributeError: Can't get attribute 'EnsembleModel' on <module '__main__' from '<input>'>

The error seems to come from the pickling/unpickling process, but I don't know how to fix it. Any ideas on how to fix this error?
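The same kind of AttributeError can be reproduced without PyCharm or the package: pickle finds a module with the recorded name, but that module has no attribute EnsembleModel. In the sketch below, training_main is a hypothetical module standing in for the __main__ of the training run. Deleting the class from it imitates the loading process, where __main__ is the console module (<input>), which, judging by the error, does not define EnsembleModel.

```python
import pickle
import sys
import types

# Hypothetical module playing the role of __main__ in the training run
training_main = types.ModuleType("training_main")
sys.modules["training_main"] = training_main


class EnsembleModel:
    def __init__(self, estimators):
        self.estimators = estimators


# Training: the pickle records training_main.EnsembleModel
EnsembleModel.__module__ = "training_main"
training_main.EnsembleModel = EnsembleModel
payload = pickle.dumps(EnsembleModel(["est_1", "est_2"]))

# Loading: a module with the recorded name exists, but it does not define the class
del training_main.EnsembleModel

caught = None
try:
    pickle.loads(payload)
except AttributeError as exc:
    caught = exc  # keep a reference; "exc" is unbound after the except block
    print("AttributeError:", exc)
```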

Tags: python, python-asyncio, pickle
1 Answer (0 votes)

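This is a late answer, but it is worth mentioning. I was able to solve the problem by overriding the 'model' module through sys.modules, because 'model' is the name under which the module had been imported in the training environment: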

import asyncio
import sys

import my_package
from my_package.model import EnsembleModel, MyModel

sys.modules['model'] = my_package.model


async def main():
    model = MyModel(
        model_name="My Ensemble Model",
        model_path="ensemble-model.mdl",
        threshold=0.5
    )


loop = asyncio.get_event_loop()
loop.run_until_complete(main())
loop.close()
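This fix works because pickle resolves a class by importing the recorded module name, and the import machinery checks sys.modules first, so an entry registered there under the old name is found. The self-contained sketch below imitates that. model is the module name recorded at training time (as in the answer), while my_package_model is a hypothetical stand-in for the installed my_package.model. In general, the name to alias is whatever module name the pickle actually recorded.

```python
import pickle
import sys
import types

# Hypothetical stand-in for the installed my_package.model module
installed_module = types.ModuleType("my_package_model")


class EnsembleModel:
    def __init__(self, estimators):
        self.estimators = estimators


# Training environment: the module was importable under the top-level name "model"
EnsembleModel.__module__ = "model"
installed_module.EnsembleModel = EnsembleModel
sys.modules["model"] = installed_module
payload = pickle.dumps(EnsembleModel(["est_1", "est_2"]))  # records model.EnsembleModel

# Loading environment: no top-level "model" module exists any more
del sys.modules["model"]
try:
    pickle.loads(payload)
except (ImportError, AttributeError) as exc:
    print("without the alias:", type(exc).__name__)

# The answer's fix: register the installed module under the recorded name
sys.modules["model"] = installed_module
restored = pickle.loads(payload)
print(type(restored) is EnsembleModel, restored.estimators)  # True ['est_1', 'est_2']
```

The alias has to be in place before pickle.load runs inside MyModel, which is why the answer sets it at module level in main.py, before main() is executed.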