如何决定torchsummary.summary(model=model.policy, input_size=(int, int, int))的'input_size'参数?

问题描述 投票:0回答:1

这是我的 CNN 网络,由“print(model.policy)”打印:

CnnPolicy(
  (actor): Actor(
    (features_extractor): CustomCNN(
      (cnn): Sequential(
        (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
        (1): ReLU()
        (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
        (3): ReLU()
        (4): Flatten(start_dim=1, end_dim=-1)
      )
      (linear): Sequential(
        (0): Linear(in_features=6, out_features=128, bias=True)
        (1): ReLU()
      )
    )
    (mu): Sequential(
      (0): Linear(in_features=128, out_features=256, bias=True)
      (1): ReLU()
      (2): Linear(in_features=256, out_features=128, bias=True)
      (3): ReLU()
      (4): Linear(in_features=128, out_features=3, bias=True)
      (5): Tanh()
    )
  )

当我尝试使用 torchsummary.summary(model=model.policy, input_size=(1, 32, 32)) 打印网络架构时。我收到以下错误: 运行时错误:mat1 和 mat2 形状无法相乘(2x50176 和 6x128)

我尝试了很多“input_size”组合,但都是错误的。

我想知道如何选择'input-size'参数?

python networking pytorch conv-neural-network
1个回答
0
投票

这不是总结的问题,而是你网络的问题。我认为由于

Flatten()
层及其第二个参数,您对层数感到困惑。

我建议您逐层组装网络,并通过输入随机

x = torch.from_numpy(np.random.rand(batch_dim, channel_dim, spatial1, spatial2)
来测试它,看看它是否可以很好地协同工作。

Flatten 通常用于展平通道和空间维度,但不展平批量维度。您将通道和一个空间维度展平,这可能不是您想要的。

此外,检查您的输入通道是否适合之前的输出通道。如果您提供一个复制粘贴的示例,而不仅仅是结构,我可以调试您的网络。

祝你好运!

© www.soinside.com 2019 - 2024. All rights reserved.