具有多个 batchnorm1d 层的 Pytorch 模型在推理过程中出现错误 - “预期的 2D 或 3D 输入(获得 1D 输入)”

问题描述 投票:0回答:1

我在pytorch中使用batchnorm1D的模型是这样的:

class Discriminator(nn.Module):
def __init__(self, sequenceLength):
    super(Discriminator,self).__init__()
    self.batchnorm1 = nn.BatchNorm1d(sequenceLength)
    self.batchnorm2 = nn.BatchNorm1d(2*sequenceLength)
    self.linear1 = nn.Linear(sequenceLength, 2*sequenceLength)
    self.conv2 = nn.Conv1d(1, 1,kernel_size=3, stride=1, padding=1)
    self.conv3 = nn.Conv1d(1, 1,kernel_size=3, stride=1, padding=1)
    self.linear4 = nn.Linear(2*sequenceLength, 1)
    self.relu = nn.ReLU(0.01)
    self.sigmoid = nn.Sigmoid()

def forward(self, x):
    out = self.batchnorm1(x)
    out = self.linear1(out)
    out = self.relu(out)
    out = self.batchnorm2(out)
    out = out.unsqueeze(1)
    out = self.conv2(out)
    out = self.sigmoid(out)
    out = self.conv3(out)
    out = self.relu(out)
    out = out.squeeze()
    out = self.batchnorm2(out)
    out = self.linear4(out)
    out = self.sigmoid(out)
    return out

我的推理代码是这样的:

Discriminator = torch.load('disc.pth', map_location=torch.device('cpu'))
Discriminator.eval()
embededSeq = Embedding.EmbedOne('sample data')
embededSeq = torch.tensor(embededSeq).float()
embededSeq = embededSeq.unsqueeze(0)
score = PosDiscriminator(embededSeq).detach().numpy()[0]

我在模型中的

out = self.batchnorm2(out) 
行收到错误消息:“预期 2D 或 3D 输入(获得 1D 输入)”。 不知道是不是我之前那行
out = out.squeeze()
造成的? 但是,训练代码运行良好,仅在推理过程中发生。

您能看一下并告诉我出了什么问题吗?

提前谢谢您,

pytorch inference batchnorm
1个回答
0
投票

是的,这个问题是由

out.squeeze
引起的。一般来说,您应该避免在没有任何输入的情况下使用
out.squeeze()
,因为这会删除所有大小为 0 或 1 的维度,从而导致维数不确定的张量。在模型中,每层期望的维度数几乎总是固定的,因此这可能会导致问题。

我们首先考虑一个二维的训练批次(batch_size,length):

  • batchnorm1 输出是 2 维的。

  • 线性1输出是二维的

  • unsqueeze 输出为 3 维 [batch_size,1,length]

  • conv2 和 sigmoid 输出是 3 维的

  • conv3 和 relu 输出是 3 维的

  • 在这里你调用out.squeeze()。在训练情况下,

    batch_size
    大于1,因此维度0不被压缩,仅移除维度1,产生大小为[batch_size,length]的张量。

  • 在推理案例中,您的

    batch_size
    为 1,因此维度 0 和 1 都被压缩。结果是形状为 [length] 的张量。

  • 下一个batchnorm层需要2D或3D输入(batch_size,channels,length)或(batch_size,length)。在批量大小为 1(推理)的情况下,输入违反了此期望,从而导致错误。

© www.soinside.com 2019 - 2024. All rights reserved.