使用 pytorch 在本地加载 *.pth 检查点

问题描述 投票:0回答:1

我尝试从本地文件离线加载 VGG19 检查点,而不是常规的 pytorch 方法(在线下载),但遇到了问题。 所以基本上我正在这样做: https://pytorch.org/tutorials/advanced/neural_style_tutorial.html

而不是

cnn = models.vgg19(pretrained=True).features.to(device).eval()

这与其他方法配合得很好,我想从本地 *.pth 文件(相同的“vgg19-dcbb9e9d.pth”,放入特定文件夹中)开始工作,然后我尝试使用此方法:

checkpoint = torch.load('models/vgg19-dcbb9e9d.pth')
cnn = models.vgg19()
cnn.load_state_dict(checkpoint)
cnn.eval()

但随后出现错误

---> 32             raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))
     33 
     34         model.add_module(name, layer)

RuntimeError: Unrecognized layer: Sequential

基本上模型未正确加载或读取,因为它似乎没有找到代码正在寻找的层。 我有什么遗漏的吗?

python pytorch conv-neural-network vgg-net
1个回答
0
投票

也许不需要分类器层。

С检查两者:

print("Model's state_dict:")
for param_tensor in cnn.state_dict():
    print(param_tensor, "\t", cnn.state_dict()[param_tensor].size())

如果您只需要功能,那么

model = copy.deepcopy(cnn.features)
model.to(device)

for param in model.parameters():
    param.requires_grad = False
© www.soinside.com 2019 - 2024. All rights reserved.