是否可以查看torch预训练网络的代码

问题描述 投票:0回答:1

如果您在阅读标题时认为自己是个菜鸟 - 是的,我就是。

我用谷歌搜索过,但没有找到任何指南可以让我查看预训练的火炬神经网络是如何设计/编码的。我已经下载了预训练的网络(文件格式.t7)并且安装了 torch。谁能帮我查看它的编码方式(使用的过滤器大小、使用的参数等)?

可能它不在谷歌上,因为这是不可能的?我们很乐意回答您的任何其他问题或任何不清楚的地方。

谢谢你。

pytorch pre-trained-model
1个回答
1
投票

我认为无法获取底层代码。但是您只需使用 print 即可获得模型的摘要,其中包括层和主要参数。

model = SumModel(vocab_size=vocab_size, hiddem_dim=hidden_dim, batch_size=batch_size)
# saving model
torch.save(model, 'test_model.save')
# print summary of original
print(' - original model summary:')
print(model)
print()

# load saved model
loaded_model = torch.load('test_model.save')
# print summary of loaded model
print(' - loaded model summary:')
print(loaded_model)

这将输出如下所示的摘要。

  - original model summary:
SumModel(
  (word_embedding): Embedding(530734, 128)
  (encoder): LSTM(128, 128, batch_first=True)
  (decoder): LSTM(128, 128, batch_first=True)
  (output_layer): Linear(in_features=128, out_features=530734, bias=True)
)

 - loaded model summary:
SumModel(
  (word_embedding): Embedding(530734, 128)
  (encoder): LSTM(128, 128, batch_first=True)
  (decoder): LSTM(128, 128, batch_first=True)
  (output_layer): Linear(in_features=128, out_features=530734, bias=True)
)

使用 Pytorch 0.4.0 进行测试

如您所见,原始模型和加载模型的输出是一致的。

我希望这有帮助。

© www.soinside.com 2019 - 2024. All rights reserved.