如何计算CNN分类器中的参数?

问题描述 投票:0回答:1

我已经为 mnist 实现了 CNN 模型。我能够理解如何计算 CNN 不同层的参数和形状,但我想了解如何确定分类器部分中的

in_features
out_features
,特别是
nn.Linear()
。另外,如何在
in_channels
中选择
out_channels
nn.Conv2d

class CNNclf(nn.Module):
def __init__(self):
    super().__init__()
    self.net = nn.Sequential(
        nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3),
        nn.ReLU(),
        nn.MaxPool2d((2, 2), stride=2),
        nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3),
        nn.ReLU(),
        nn.MaxPool2d((2, 2), stride=3),
        nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3),
        nn.ReLU(),
        nn.MaxPool2d((2, 2), stride=2))
    self.clf = nn.Sequential(
        nn.Flatten(),
        nn.Linear(64, 20, bias=True),
        nn.ReLU(),
        nn.Linear(20, 10, bias=True))

def forward(self, x):
    x = self.net(x)
    x = self.clf(x)
    return x
deep-learning pytorch conv-neural-network
1个回答
0
投票

这是一个关于 CNN 的典型问题,有很多类似的帖子,用户都遇到了源于此的相同类型的错误。

运行时错误:mat1 和 mat2 形状无法相乘(

i
x
j
k
x
l

我将在这里提供规范的答案。

使用

nn.Conv2d
时,中间张量将是四维的:
(b, c, h, w)
。卷积在空间上工作,它们在高度和宽度维度上跨张量移动(对于 2D 卷积)。输出通道的数量由卷积层中相互独立运行的滤波器的数量决定。您可以阅读有关卷积层和大小的更多信息:了解卷积层形状

对于 CNN 架构,从卷积部分(特征提取器)转移到全连接层(分类器)时,您必须适应维度的变化。一般情况下,张量形状从 4D 变为 2D。这需要某种形式的空间减少:要么

  • 通过压平,形成
    (b, c*h*w)
    的形状。这可以使用
    nn.Flatten
    ;
  • 来完成
  • 或使用池化操作,例如最大
    nn.MaxPool2d
    或平均池化
    nn.AdaptiveAvgPool2d
    ,导致
    (b, c, h', w')
    的形状缩小。如果
    h'
    w'
    不是单例,仍然需要展平操作。

最终最后一个卷积层的输出形状取决于两件事:输入形状以及其前面的卷积的数量和大小。 上述错误是指 CNN 的输出与线性层期望的形状之间的形状不匹配。

i
是批量大小,
j
实际展平特征长度
k
是第一个线性层的
in_features
l
是其
out_features
。因此,如果您收到此错误,您已经知道要使用哪个
in_features

为了预测此错误并避免在调试架构时抛出它,确定

in_features
的另一种方法是截断模型(删除所有线性层)并使用虚拟数据执行推理。观察该推理的输出形状将告诉您要采用的
in_features

>>> CNNclf().net(torch.rand(3,1,100,100)).shape # adapt with your input shape
torch.Size([3, 64, 7, 7])

因此空间维度为

7x7
,通道数为
64
,因此特征维度为
64*7*7 = 3136
。在这种情况下,第一个线性层必须初始化为
nn.Linear(3136, 20, bias=True)

1.8 版开始,存在一个类

nn.LazyLinear
,它会在运行时(在模型的第一次推理期间)自动推断
in_features
。在这种情况下,不需要自己进行虚拟推理,只需使用
nn.LazyLinear(20, bias=True)

© www.soinside.com 2019 - 2024. All rights reserved.