如何在 Pytorch 中可视化网络?

问题描述 投票:0回答:5
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.models as models
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
from torchvision.models.vgg import model_urls
from torchviz import make_dot

batch_size = 3
learning_rate =0.0002
epoch = 50

resnet = models.resnet50(pretrained=True)
print resnet
make_dot(resnet)

我想从 pytorch 模型中可视化

resnet
。我该怎么做?我尝试使用
torchviz
但它给出了一个错误:

'ResNet' object has no attribute 'grad_fn'
python pytorch
5个回答
97
投票

这里是使用不同工具的三种不同的图形可视化。

为了生成示例可视化,我将使用一个简单的 RNN 来执行取自在线教程的情感分析:

class RNN(nn.Module):

    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):

        super().__init__()
        self.embedding  = nn.Embedding(input_dim, embedding_dim)
        self.rnn        = nn.RNN(embedding_dim, hidden_dim)
        self.fc         = nn.Linear(hidden_dim, output_dim)

    def forward(self, text):

        embedding       = self.embedding(text)
        output, hidden  = self.rnn(embedding)

        return self.fc(hidden.squeeze(0))

这是输出,如果你

print()
模型。

RNN(
  (embedding): Embedding(25002, 100)
  (rnn): RNN(100, 256)
  (fc): Linear(in_features=256, out_features=1, bias=True)
)

以下是三种不同可视化工具的结果。

对于所有这些,您需要具有可以通过模型的

forward()
方法的虚拟输入。获取此输入的一种简单方法是从您的 Dataloader 中检索一个批次,如下所示:

batch = next(iter(dataloader_train))
yhat = model(batch.text) # Give dummy batch to forward().

火炬可视化

https://github.com/szagoruyko/pytorchviz

我相信这个工具使用向后传递生成它的图形,所以所有的盒子都使用 PyTorch 组件进行反向传播。

from torchviz import make_dot

make_dot(yhat, params=dict(list(model.named_parameters()))).render("rnn_torchviz", format="png")

此工具生成以下输出文件:

这是唯一明确提到我的模型中的三个层的输出,

embedding
rnn
fc
。运算符名称取自反向传递,因此其中一些难以理解。

隐藏层

https://github.com/waleedka/hiddenlayer

我相信这个工具使用前向传递。

import hiddenlayer as hl

transforms = [ hl.transforms.Prune('Constant') ] # Removes Constant nodes from graph.

graph = hl.build_graph(model, batch.text, transforms=transforms)
graph.theme = hl.graph.THEMES['blue'].copy()
graph.save('rnn_hiddenlayer', format='png')

这是输出。我喜欢蓝色的阴影。

我发现输出的细节太多,混淆了我的架构。例如,为什么

unsqueeze
被提到这么多次?

耐创

https://github.com/lutzroeder/netron

此工具是适用于 Mac、Windows 和 Linux 的桌面应用程序。它依赖于首先导出为 ONNX 格式 的模型。然后应用程序读取 ONNX 文件并呈现它。然后可以选择将模型导出到图像文件。

input_names = ['Sentence']
output_names = ['yhat']
torch.onnx.export(model, batch.text, 'rnn.onnx', input_names=input_names, output_names=output_names)

这是模型在应用程序中的样子。我认为这个工具非常灵巧:您可以缩放和平移,还可以钻取图层和运算符。我发现的唯一缺点是它只能进行垂直布局。


39
投票

make_dot
需要一个变量(即带有
grad_fn
的张量),而不是模型本身。
尝试:

x = torch.zeros(1, 3, 224, 224, dtype=torch.float, requires_grad=False)
out = resnet(x)
make_dot(out)  # plot graph of variable, not of a nn.Module

16
投票

您可以查看 PyTorchViz (https://github.com/szagoruyko/pytorchviz),“用于创建 PyTorch 执行图和轨迹可视化的小包。”


13
投票

如果要保存图像,请使用

torchviz
执行此操作:

# http://www.bnikolic.co.uk/blog/pytorch-detach.html

import torch
from torchviz import make_dot

x=torch.ones(10, requires_grad=True)
weights = {'x':x}

y=x**2
z=x**3
r=(y+z).sum()

make_dot(r).render("attached", format="png")

得到的图片截图:

来源:http://www.bnikolic.co.uk/blog/pytorch-detach.html


11
投票

这可能是一个迟到的答案。但是,特别是随着

__torch_function__
的发展,可以获得更好的可视化效果。你可以在这里试试我的项目,torchview

对于您的 resnet50 示例,您查看 colab notebook,这里 我在这里展示了 resnet18 模型的可视化。 resnet18的图像由以下代码产生

import torchvision
from torchview import draw_graph

model_graph = draw_graph(resnet18(), input_size=(1,3,224,224), expand_nested=True)
model_graph.visual_graph

它还接受范围广泛的输出/输入类型(例如列表、字典)

© www.soinside.com 2019 - 2024. All rights reserved.