PyTorch 和 Pickle 错误:AttributeError:无法在 <module '__main__' from '/load_model.py'>

问题描述 投票:0回答:1

我使用 PyTorch 创建了一个神经网络模型来解决时间序列预测问题。我已经使用 Pickle 保存了模型。 当加载模型以检查测试数据时,它抛出属性错误。 下面,我提供了代码的相关部分:

import torch
import torch.nn as nn
import torch.nn.functional as F

class twoD_predict(nn.Module):

  def __init__(self):

    super().__init__()

  def forward(self,x):
  ...

  def train(self,
        epochs = 100):
  ...

obj1 = twoD_predict()

obj1.train()

我使用 Pickle 保存模型,如下所示:

import pickle
filename = (f"{column_names[0]}.sav")
pickle.dump(obj1, open(filename, 'wb'))

但是,当我尝试使用以下代码加载模型时:

import pickle
if __name__ == '__main__':
    with open("model.sav", 'rb') as file:
        model = pickle.load(file)

我遇到错误:

"AttributeError: Can't get attribute "twoD_predict" on <module '__main__' from '/load_model.py'>"

有人可以帮忙吗?

python-3.x pytorch pickle
1个回答
0
投票

您是否尝试将

twoD_predict
导入到要加载模型的文件中? pickled 版本仍然需要复制如何重新创建模型的结构(例如,您必须将模型的定义加载到当前文件中)。如果您不想这样做,请考虑以 TorchScript 格式导出模型。

© www.soinside.com 2019 - 2024. All rights reserved.