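运行下面的 forward 时出现以下错误:RuntimeError: Given groups=1, weight of size [16, 2, 3, 3, 3], expected input[1, 1024, 1, 512, 512] to have 2 channels, but got 1024 channels instead(即:给定 groups=1,权重大小为 [16, 2, 3, 3, 3],预期输入 [1, 1024, 1, 512, 512] 应有 2 个通道,但实际得到了 1024 个通道)。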
def forward(self, fixed, moving):
    """Run the coarse-to-fine encoder/decoder on ``fixed`` and ``moving``.

    ``fixed`` and ``moving`` are concatenated along dim 1, which the
    convolutions treat as the channel axis. Both inputs must therefore be
    channel-first, ``(N, C, ...)``, with identical remaining dims. The shape
    notes give 2 channels after concatenation (e.g. one channel per image).

    Encoder: conv1 .. conv6 each halve the resolution; the ``*_1`` convs keep
    it. Decoder: from coarsest to finest, each level predicts a 2-channel map
    (pred6 .. pred2) and upsamples it x2. It also upsamples the decoder
    features x2 (deconv5 .. deconv1) and concatenates both with the encoder
    skip features of the next finer level.

    NOTE(review): if ``conv1`` is a 3-D convolution (5-D weight such as
    ``[16, 2, 3, 3, 3]``), inputs must be 5-D ``(N, C, D, H, W)`` with depth
    on axis 2, not axis 1. An input whose axis 1 holds e.g. 512 slices makes
    the concatenation 1024 "channels", and conv1 rejects it. If the data is
    really 2-D (as the shape notes and 2-channel predictions suggest), the
    layers should presumably be ``Conv2d``/``ConvTranspose2d`` and the inputs
    ``(N, 1, H, W)``. Confirm against ``__init__`` and the data loader.

    Returns:
        ``pred0 * 20 * self.flow_multiplier``. NOTE(review): 20 is a
        hard-coded scale whose meaning is not visible here.

    Raises:
        RuntimeError: if ``self.conv1`` exposes ``weight`` and ``in_channels``,
            the concatenated input is batched (same rank as that weight), and
            its axis-1 size differs from ``in_channels``. conv1 would raise the
            same exception type; this message also names the input shapes.
    """
    concat_image = torch.cat((fixed, moving), dim=1)  # 2 ch, full res

    # Fail early with the offending input shapes: conv1's own error only shows
    # the concatenated shape. Only batched input (same rank as conv1's weight)
    # uses axis 1 as the channel axis; anything else is left to conv1 itself.
    conv1_weight = getattr(self.conv1, "weight", None)
    expected_channels = getattr(self.conv1, "in_channels", None)
    if (
        conv1_weight is not None
        and expected_channels is not None
        and concat_image.dim() == conv1_weight.dim()
        and concat_image.shape[1] != expected_channels
    ):
        raise RuntimeError(
            f"forward(): conv1 expects {expected_channels} input channels, but "
            f"torch.cat((fixed, moving), dim=1) has {concat_image.shape[1]} "
            f"(fixed.shape={tuple(fixed.shape)}, "
            f"moving.shape={tuple(moving.shape)}). Inputs must be "
            f"channel-first (N, C, ...) with the channel counts of fixed and "
            f"moving summing to {expected_channels}."
        )

    # Encoder. "res" is relative to the input's in-plane size
    # (512 -> 256 -> ... -> 8 in the original notes).
    x1 = self.conv1(concat_image)  # 16 ch, 1/2 res
    x2 = self.conv2(x1)  # 32 ch, 1/4 res
    x3 = self.conv3(x2)  # 64 ch, 1/8 res
    x3_1 = self.conv3_1(x3)  # 64 ch, 1/8 res
    x4 = self.conv4(x3_1)  # 128 ch, 1/16 res
    x4_1 = self.conv4_1(x4)  # 128 ch, 1/16 res
    x5 = self.conv5(x4_1)  # 256 ch, 1/32 res
    x5_1 = self.conv5_1(x5)  # 256 ch, 1/32 res
    x6 = self.conv6(x5_1)  # 512 ch, 1/64 res
    x6_1 = self.conv6_1(x6)  # 512 ch, 1/64 res

    # Decoder, coarse to fine.
    pred6 = self.pred6(x6_1)  # 2 ch, 1/64 res
    upsamp6to5 = self.upsamp6to5(pred6)  # 2 ch, 1/32 res
    deconv5 = self.deconv5(x6_1)  # 256 ch, 1/32 res
    concat5 = torch.cat([x5_1, deconv5, upsamp6to5], dim=1)  # 256+256+2 = 514 ch

    pred5 = self.pred5(concat5)  # 2 ch, 1/32 res
    upsamp5to4 = self.upsamp5to4(pred5)  # 2 ch, 1/16 res
    deconv4 = self.deconv4(concat5)  # 128 ch, 1/16 res
    concat4 = torch.cat([x4_1, deconv4, upsamp5to4], dim=1)  # 128+128+2 = 258 ch

    pred4 = self.pred4(concat4)  # 2 ch, 1/16 res
    upsamp4to3 = self.upsamp4to3(pred4)  # 2 ch, 1/8 res
    deconv3 = self.deconv3(concat4)  # 64 ch, 1/8 res
    concat3 = torch.cat([x3_1, deconv3, upsamp4to3], dim=1)  # 64+64+2 = 130 ch

    pred3 = self.pred3(concat3)  # 2 ch, 1/8 res
    upsamp3to2 = self.upsamp3to2(pred3)  # 2 ch, 1/4 res
    deconv2 = self.deconv2(concat3)  # 32 ch, 1/4 res
    concat2 = torch.cat([x2, deconv2, upsamp3to2], dim=1)  # 32+32+2 = 66 ch

    pred2 = self.pred2(concat2)  # 2 ch, 1/4 res
    upsamp2to1 = self.upsamp2to1(pred2)  # 2 ch, 1/2 res
    deconv1 = self.deconv1(concat2)  # 16 ch, 1/2 res
    concat1 = torch.cat([x1, deconv1, upsamp2to1], dim=1)  # 16+16+2 = 34 ch

    # Assumes pred0 also upsamples x2, so the output is at full input
    # resolution (the original note said 2 x 512 x 512) -- confirm in __init__.
    pred0 = self.pred0(concat1)  # 2 ch, full res
    return pred0 * 20 * self.flow_multiplier