PyTorch 图像配准模型 forward 报错：第一层卷积的输入通道数不匹配

问题描述 投票:0回答:0

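运行时出现以下错误（PyTorch 原始报错信息）：RuntimeError: Given groups=1, weight of size [16, 2, 3, 3, 3], expected input[1, 1024, 1, 512, 512] to have 2 channels, but got 1024 channels instead。也就是说，第一层卷积的权重是 5 维的（3D 卷积，输入通道数为 2），它要求输入的第 1 维（通道维）是 2；但 `torch.cat((fixed, moving), dim=1)` 拼接后的张量第 1 维却是 1024。相关的 forward 代码如下：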

def forward(self, fixed, moving):
    """Predict a (presumably displacement/flow) field relating ``moving`` to ``fixed``.

    ``fixed`` and ``moving`` are concatenated along dim 1 (the channel axis
    in PyTorch's channel-first convention). They are passed through
    ``conv1`` .. ``conv6_1``, which the shape comments describe as
    progressively halving the spatial size. At each coarser-to-finer scale
    the network then concatenates:

    - the matching ``conv*`` skip feature,
    - a ``deconv*`` output, and
    - an upsampled ``pred*`` from the previous scale.

    The shape comments are per sample (batch dim omitted). They assume
    512 x 512 inputs with one channel per image.

    NOTE(review): if ``conv1`` is a 3-D convolution (5-D weight such as
    ``[16, 2, 3, 3, 3]``), inputs must be batched as ``(N, C, D, H, W)``
    with ``C_fixed + C_moving == conv1.in_channels``. In that case the
    spatial comments omit D. Confirm this against the layer definitions.

    Args:
        fixed: Reference image tensor, channel axis at dim 1.
        moving: Image tensor to align. ``torch.cat`` requires it to match
            ``fixed`` in every dimension except dim 1.

    Returns:
        ``pred0 * 20 * self.flow_multiplier``: the finest-scale prediction
        scaled by the constant 20 and the ``flow_multiplier`` attribute.

    Raises:
        RuntimeError: If the concatenated input has the wrong channel count
            for ``conv1``. This is the same exception type the convolution
            itself would raise, but the message includes the input shapes.
            Also raised if ``torch.cat`` or any layer rejects the tensors.
    """
    concat_image = torch.cat((fixed, moving), dim=1)  # 2 x 512 x 512

    # Fail early with a message that names the offending input shapes.
    # The check only fires when all of these hold:
    # - conv1 exposes a non-zero ``in_channels`` and a ``weight``, as
    #   torch.nn.ConvNd modules do (lazy convs report 0 before their first
    #   call and are skipped);
    # - the input is batched, i.e. has the same rank as the weight.
    # That is exactly the case where conv1 would raise this RuntimeError
    # itself, only with a less helpful message.
    expected_channels = getattr(self.conv1, "in_channels", None)
    conv1_weight = getattr(self.conv1, "weight", None)
    if (
        expected_channels
        and conv1_weight is not None
        and concat_image.dim() == conv1_weight.dim()
        and concat_image.shape[1] != expected_channels
    ):
        raise RuntimeError(
            f"conv1 expects {expected_channels} input channels at dim 1, but "
            f"torch.cat((fixed, moving), dim=1) produced shape "
            f"{tuple(concat_image.shape)} (fixed: {tuple(fixed.shape)}, "
            f"moving: {tuple(moving.shape)}). Inputs must be channel-first, "
            f"e.g. (N, 1, D, H, W) each for a 3-D convolution; "
            f"permute/unsqueeze them before calling the model."
        )

    x1 = self.conv1(concat_image)  # 16 x 256 x 256
    x2 = self.conv2(x1)  # 32 x 128 x 128
    x3 = self.conv3(x2)  # 64 x 64 x 64
    x3_1 = self.conv3_1(x3)  # 64 x 64 x 64
    x4 = self.conv4(x3_1)  # 128 x 32 x 32
    x4_1 = self.conv4_1(x4)  # 128 x 32 x 32
    x5 = self.conv5(x4_1)  # 256 x 16 x 16
    x5_1 = self.conv5_1(x5)  # 256 x 16 x 16
    x6 = self.conv6(x5_1)  # 512 x 8 x 8
    x6_1 = self.conv6_1(x6)  # 512 x 8 x 8

    pred6 = self.pred6(x6_1)  # 2 x 8 x 8
    upsamp6to5 = self.upsamp6to5(pred6)  # 2 x 16 x 16
    deconv5 = self.deconv5(x6_1)  # 256 x 16 x 16
    concat5 = torch.cat([x5_1, deconv5, upsamp6to5], dim=1)  # 514 x 16 x 16

    pred5 = self.pred5(concat5)  # 2 x 16 x 16
    upsamp5to4 = self.upsamp5to4(pred5)  # 2 x 32 x 32
    # 128 channels: concat4 below is 258 = 128 (x4_1) + 128 + 2.
    deconv4 = self.deconv4(concat5)  # 128 x 32 x 32
    concat4 = torch.cat([x4_1, deconv4, upsamp5to4], dim=1)  # 258 x 32 x 32

    pred4 = self.pred4(concat4)  # 2 x 32 x 32
    upsamp4to3 = self.upsamp4to3(pred4)  # 2 x 64 x 64
    deconv3 = self.deconv3(concat4)  # 64 x 64 x 64
    concat3 = torch.cat([x3_1, deconv3, upsamp4to3], dim=1)  # 130 x 64 x 64

    pred3 = self.pred3(concat3)  # 2 x 64 x 64
    upsamp3to2 = self.upsamp3to2(pred3)  # 2 x 128 x 128
    deconv2 = self.deconv2(concat3)  # 32 x 128 x 128
    concat2 = torch.cat([x2, deconv2, upsamp3to2], dim=1)  # 66 x 128 x 128

    pred2 = self.pred2(concat2)  # 2 x 128 x 128
    upsamp2to1 = self.upsamp2to1(pred2)  # 2 x 256 x 256
    deconv1 = self.deconv1(concat2)  # 16 x 256 x 256
    concat1 = torch.cat([x1, deconv1, upsamp2to1], dim=1)  # 34 x 256 x 256

    pred0 = self.pred0(concat1)  # 2 x 512 x 512

    return pred0 * 20 * self.flow_multiplier
image registration medical
© www.soinside.com 2019 - 2024. All rights reserved.