在Python中,Pytorch nn.Module的派生类不能通过模块导入来加载。

问题描述 投票:0回答:1

使用Python 3.6和Pytorch 1.3.1。我注意到,当整个模块被导入到另一个模块中时,一些保存的nn.Modules不能被加载。举个例子,这里是一个最小工作示例的模板。

#!/usr/bin/env python3
#encoding:utf-8
# file 'dnn_predict.py'

from torch import nn
class NN(nn.Module):##NN network
    # Initialisation and other class methods

networks=[torch.load(f=os.path.join(resource_directory, 'nn-classify-cpu_{fold}.pkl'.format(fold=fold))) for fold in range(5)]
...
if __name__=='__main__':
    # Some testing snippets
    pass

当我直接在shell中运行它时,整个文件工作得很好。但是,当我想使用这个类,并使用这段代码在另一个文件中加载神经网络时,它失败了。

#!/usr/bin/env python3
#encoding:utf-8
from dnn_predict import *

错误内容如下 AttributeError: Can't get attribute 'NN' on <module '__main__'>

在Pytorch中加载保存的变量或导入模块的情况是否与其他常见的Python库不同?如果能得到一些帮助或指出根本原因,我将非常感激。

pytorch python-import python-module python-class
1个回答
0
投票

当你保存一个模型时 torch.save(model, PATH) 整个对象会被序列化,用 pickle,它并不保存类本身,而是保存一个包含类的文件的路径,因此在加载模型时,需要完全相同的目录和文件结构才能找到正确的类。当运行一个Python脚本时,该文件的模块是 __main__因此,如果你想加载该模块,你的 NN 类必须在你正在运行的脚本中定义。

这是非常不灵活的,所以推荐的方法是不要保存整个模型,而只保存状态字典,它只保存模型的参数。

# Save the state dictionary of the model
torch.save(model.state_dict(), PATH)

之后,可以加载状态字典并应用到你的模型中。

from dnn_predict import NN

# Create the model (will have randomly initialised parameters)
model = NN()

# Load the previously saved state dictionary
state_dict = torch.load(PATH)

# Apply the state dictionary to the model
model.load_state_dict(state_dict)

有关状态字典和保存加载模型的更多详情。PyTorch - 保存和加载模型

© www.soinside.com 2019 - 2024. All rights reserved.