Unpickling a saved PyTorch model throws AttributeError: Can't get attribute 'Net' on <module '__main__'> despite adding the class definition inline

Votes: 0, Answers: 8

I'm trying to serve a PyTorch model in a Flask app. This code worked when I ran it in a Jupyter notebook earlier, but now I'm running it inside a virtualenv and apparently it can't get the attribute 'Net', even though the class definition is right there. All the other similar questions tell me to add the class definition of the saved model in the same script, but it still doesn't work. The torch version is 1.0.1 (both where the saved model was trained and in the virtualenv). What am I doing wrong? Here's my code.

import os
import numpy as np
from flask import Flask, request, jsonify 
import requests

import torch
from torch import nn
from torch.nn import functional as F


MODEL_URL = 'https://storage.googleapis.com/judy-pytorch-model/classifier.pt'


r = requests.get(MODEL_URL)
file = open("model.pth", "wb")
file.write(r.content)
file.close()

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = torch.sigmoid(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        x = self.fc3(x)

        return F.log_softmax(x, dim=-1)

model = torch.load('model.pth')

app = Flask(__name__)

@app.route("/")
def hello():
    return "Binary classification example\n"

@app.route('/predict', methods=['GET'])
def predict():


    x_data = request.args['x_data']

    x_data =  x_data.split()
    x_data = list(map(float, x_data))

    sample = np.array(x_data) 

    sample_tensor = torch.from_numpy(sample).float()

    out = model(sample_tensor)

    _, predicted = torch.max(out.data, -1)

    if predicted.item() == 0: 
         pred_class = "Has no liver damage - ", predicted.item()
    elif predicted.item() == 1:
        pred_class = "Has liver damage - ", predicted.item()

    return jsonify(pred_class)

Here is the full traceback:

Traceback (most recent call last):
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/bin/flask", line 10, in <module>
    sys.exit(main())
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 894, in main
    cli.main(args=args, prog_name=name)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 557, in main
    return super(FlaskGroup, self).main(*args, **kwargs)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 717, in main
    rv = self.invoke(ctx)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 1137, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 956, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 555, in invoke
    return callback(*args, **kwargs)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/decorators.py", line 64, in new_func
    return ctx.invoke(f, obj, *args, **kwargs)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 555, in invoke
    return callback(*args, **kwargs)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 767, in run_command
    app = DispatchingApp(info.load_app, use_eager_loading=eager_loading)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 293, in __init__
    self._load_unlocked()
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 317, in _load_unlocked
    self._app = rv = self.loader()
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 372, in load_app
    app = locate_app(self, import_name, name)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 235, in locate_app
    __import__(module_name)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/app.py", line 34, in <module>
    model = torch.load('model.pth')
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/torch/serialization.py", line 368, in load
    return _load(f, map_location, pickle_module)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/torch/serialization.py", line 542, in _load
    result = unpickler.load()
AttributeError: Can't get attribute 'Net' on <module '__main__' from '/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/bin/flask'>

This doesn't solve my problem. I don't want to change the way I persist the model. torch.save() worked fine for me outside the virtualenv. I don't mind adding the class definition to the script. Still, I'd like to understand what is causing the error.

python pickle pytorch
8 Answers

8 votes

(This is a partial answer)

I don't think torch.save(model, 'model.pt') works from the command prompt, or when a model saved from one script running as '__main__' is loaded from another script.

The reason is that torch must automatically load the module that was used to save the file, and it gets the module name from __name__.

Now for the partial part: it's unclear how to work around this, especially when virtualenvs are mixed in.

Thanks to Jatentaki for starting the conversation in this direction.
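The mechanism this answer describes can be seen with the standard library alone. A minimal sketch (plain pickle, no torch involved — torch.save delegates to pickle underneath): the class body is never serialized, only a reference by module name and class name.

```python
import pickle

class Net:
    pass

# pickle stores a *reference* to the class (its module and qualified
# name), not the class definition itself.
payload = pickle.dumps(Net())

# When this file runs as a script, Net.__module__ is "__main__", and that
# exact string is baked into the stream -- so a different entry point
# (such as the flask CLI binary) has no 'Net' attribute to resolve.
print(Net.__module__)
print(b"Net" in payload)
```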


5 votes

I know I'm late answering this, but I found a way to load the model from a package other than "__main__".

Before loading the model, if you set the attribute dynamically as follows, it will work.

import __main__
setattr(__main__, "Net", Net)
model = torch.load(os.path.join(parent_dir,"<path to pickle>"), map_location=torch.device("cpu"))

Note: this hack will not work if "__main__" is a binary.
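A runnable stdlib illustration of this hack — plain pickle instead of torch.load, and a hypothetical module name "training_script" standing in for "__main__":

```python
import pickle
import sys
import types

# Hypothetical stand-in: a class pickled under one module name, then
# loaded in a process where that module no longer defines it.
training_script = types.ModuleType("training_script")

class Net:
    def __init__(self):
        self.layers = 3

Net.__module__ = "training_script"
training_script.Net = Net
sys.modules["training_script"] = training_script

payload = pickle.dumps(Net())

# Simulate the broken process: the attribute is gone at load time.
del training_script.Net
try:
    pickle.loads(payload)
    err = None
except AttributeError as e:
    err = str(e)  # Can't get attribute 'Net' on <module 'training_script'>

# The setattr fix from this answer: re-bind the class, then load again.
setattr(training_script, "Net", Net)
restored = pickle.loads(payload)
print(err)
print(restored.layers)
```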


3 votes

Here, saving and loading the model is done with pickle under the hood. The disadvantage of this approach is that the serialized data is bound to the specific classes and the exact directory structure used when the model was saved. The reason is that pickle does not save the model class itself; instead, it saves a path to the file containing the class, which is used at load time. Because of this, your code can break in various ways when used in other projects or after refactoring (see the documentation for more).

Solution: instead of using torch.save(model, PATH), use the following:

# Saving model
model_scripted = torch.jit.script(model) # Export to TorchScript
model_scripted.save('model_scripted.pt') # Save

# Loading model
model = torch.jit.load('model_scripted.pt')
model.eval()

1 vote

First I initialized an empty model, then loaded the saved model, and for some reason that fixed the problem.


1 vote

A simple fix for your problem: you need to define "class Net(nn.Module):" before loading the model. That will solve the issue.
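The same rule holds for plain pickle, which torch.load uses under the hood. A minimal sketch (no torch involved): the load succeeds only because the class is already bound in the loading module.

```python
import pickle

class Net:
    def __init__(self):
        self.trained = True

# "Saving": what torch.save does under the hood, via pickle.
blob = pickle.dumps(Net())

# "Loading": works only because 'Net' is defined in this module at the
# moment pickle.loads runs -- the same requirement torch.load has.
model = pickle.loads(blob)
print(model.trained)
```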


0 votes

Simple solution:

  1. You just need to create an instance of the class Net(nn.Module) as shown below, and then it runs fine.
  2. I ran into the same problem myself and solved it with these simple steps.
import numpy as np
import requests
from flask import Flask, request, jsonify

import torch
from torch import nn
from torch.nn import functional as F


MODEL_URL = 'https://storage.googleapis.com/judy-pytorch-model/classifier.pt'


r = requests.get(MODEL_URL)
file = open("model.pth", "wb")
file.write(r.content)
file.close()

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = torch.sigmoid(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        x = self.fc3(x)

        return F.log_softmax(x, dim=-1)

model = Net()#<---------------------------- Extra thing added
model = torch.load('model.pth', map_location=torch.device('cpu'))  # <---- if running on a CPU, else 'cuda'

app = Flask(__name__)

@app.route("/")
def hello():
    return "Binary classification example\n"

@app.route('/predict', methods=['GET'])
def predict():


    x_data = request.args['x_data']

    x_data =  x_data.split()
    x_data = list(map(float, x_data))

    sample = np.array(x_data) 

    sample_tensor = torch.from_numpy(sample).float()

    out = model(sample_tensor)

    _, predicted = torch.max(out.data, -1)

    if predicted.item() == 0: 
         pred_class = "Has no liver damage - ", predicted.item()
    elif predicted.item() == 1:
        pred_class = "Has liver damage - ", predicted.item()

    return jsonify(pred_class)

0 votes

I recently stumbled on the same problem and solved it by saving the model a different way.

When I saved it like this:

torch.save(model, 'model_name.pth')

and then loaded it like this:

loaded_model = torch.load('model_name.pth')

in the Flask app, I got an error about Flask not being able to find the custom class I declared for the model during training, even though the code for this model class was copied into the Flask app code before the model-loading line.

But when I changed the saving code to:

torch.save(loaded_model.state_dict(), 'model_name.pth')

and the loading code to:

loaded_model = TheModelClass(*args, **kwargs)
loaded_model.load_state_dict(torch.load('model_name.pth'))

everything worked fine. (Of course, as the docs say, you need to declare the custom model class before loading the model in the Flask app code.)

Hope this helps!


-1 votes

This may not be a very popular answer, but I've found the dill package to be very consistent at making my code work. In my case I wasn't even trying to load a model; I was trying to unpickle a custom object that helps with my stuff, and for some reason it couldn't be found. I don't know why, but in my experience dill seems to be a better option than pickle:

    # - path to files
    path = Path(path2dataset).expanduser()
    path2file_data_prep = Path(path2file_data_prep).expanduser()
    # - create dag dataprep obj
    print(f'path to data set {path=}')
    dag_prep = SplitDagDataPreparation(path)
    # - save data prep splits object
    print(f'saving to {path2file_data_prep=}')
    torch.save({'data_prep': dag_prep}, path2file_data_prep, pickle_module=dill)
    # - load the data prep splits object to test it loads correctly
    db = torch.load(path2file_data_prep, pickle_module=dill)
    db['data_prep']
    print(db)
    return path2file_data_prep