Deep learning: my U-Net model keeps raising a channel mismatch error

Question

[Figure: U-Net diagram]

When I run the model on a batch x to get the predictions y, it fails with a channel mismatch:

Error:
RuntimeError: Given groups=1, weight of size [512, 1024, 3, 3], expected input[16, 1536, 16, 16] to have 1024 channels, but got 1536 channels instead.
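The message can be decoded from the weight shape: PyTorch stores a Conv2d weight as [out_channels, in_channels, kernel_h, kernel_w], so [512, 1024, 3, 3] belongs to a layer built as Conv2d(1024, 512, kernel_size=3), and input[16, 1536, 16, 16] is a batch of 16 feature maps with 1536 channels at 16x16. A minimal sketch that reproduces the same kind of error in isolation (batch size reduced to 1; the variable names are illustrative only):

```python
import torch

# Same weight shape as in the error message: 1024 input channels -> 512 output channels.
conv = torch.nn.Conv2d(1024, 512, kernel_size=3, padding="same")
print(conv.weight.shape)  # torch.Size([512, 1024, 3, 3])

# Feeding it a tensor with 1536 channels raises a RuntimeError.
bad_input = torch.zeros(1, 1536, 16, 16)
try:
    conv(bad_input)
    raised = False
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)

# The same layer accepts a 1024-channel input and returns 512 channels.
good_output = conv(torch.zeros(1, 1024, 16, 16))
print(good_output.shape)  # torch.Size([1, 512, 16, 16])
```

So the fix is always on the in_channels side: either the layer must be built for 1536 channels, or the tensor reaching it must have 1024.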
(PyTorch)

import torch

class UNet(torch.nn.Module):
  def __init__(self):
      super().__init__()

      self.conv1 = torch.nn.Conv2d(1, 64, kernel_size=3, padding='same')
      self.conv2 = torch.nn.Conv2d(64, 64, kernel_size=3, padding='same')
      self.pool1 = torch.nn.MaxPool2d(kernel_size = 2, stride = 2)

      self.conv3 = torch.nn.Conv2d(64, 128, kernel_size=3, padding='same')
      self.conv4 = torch.nn.Conv2d(128, 128, kernel_size=3, padding='same')
      self.pool2 = torch.nn.MaxPool2d(kernel_size = 2, stride = 2)

      self.conv5 = torch.nn.Conv2d(128, 256, kernel_size=3, padding='same')
      self.conv6 = torch.nn.Conv2d(256, 256, kernel_size=3, padding='same')
      self.pool3 = torch.nn.MaxPool2d(kernel_size = 2, stride = 2)

      self.conv7 = torch.nn.Conv2d(256, 512, kernel_size=3, padding='same')
      self.conv8 = torch.nn.Conv2d(512, 512, kernel_size=3, padding='same')
      self.pool4 = torch.nn.MaxPool2d(kernel_size=2, stride = 2)

      self.conv9 = torch.nn.Conv2d(512, 1024, kernel_size=3, padding='same')
      self.conv10 = torch.nn.Conv2d(1024, 1024, kernel_size=3, padding='same')
      self.Upsample = torch.nn.Upsample(scale_factor=2)

      self.conv11 = torch.nn.Conv2d(1024, 512, kernel_size=3, padding='same')
      self.conv12 = torch.nn.Conv2d(1024, 512, kernel_size=3, padding='same')
      self.conv13 = torch.nn.Conv2d(512, 512, kernel_size=3, padding='same')
      self.Upsample = torch.nn.Upsample(scale_factor=2)

      self.conv14 = torch.nn.Conv2d(512, 256, kernel_size=3, padding='same')
      self.conv15 = torch.nn.Conv2d(256, 256, kernel_size=3, padding='same')
      self.Upsample = torch.nn.Upsample(scale_factor=2)

      self.conv16 = torch.nn.Conv2d(256, 128, kernel_size=3, padding='same')
      self.conv17 = torch.nn.Conv2d(128, 128, kernel_size=3, padding='same')
      self.Upsample = torch.nn.Upsample(scale_factor=2)

      self.conv18 = torch.nn.Conv2d(128, 128, kernel_size=3, padding='same')
      self.conv19 = torch.nn.Conv2d(128, 64, kernel_size=3, padding='same')
      self.conv20 = torch.nn.Conv2d(64, 64, kernel_size=3, padding='same')
      self.conv21 = torch.nn.Conv2d(64, 1, kernel_size=1, padding='same')
      self.relu = torch.nn.ReLU()

  def forward(self, x):
    x = self.conv1(x)
    x = self.relu(x)
    x = self.conv2(x)
    x1 = self.relu(x) # We'll use this result later!
    x = self.pool1(x1)

    x = self.conv3(x)
    x = self.relu(x)
    x = self.conv4(x)
    x2 = self.relu(x)
    x = self.pool2(x)

    x = self.conv5(x)
    x = self.relu(x)
    x = self.conv6(x)
    x3 = self.relu(x)
    x = self.pool3(x)

    x = self.conv7(x)
    x = self.relu(x)
    x = self.conv8(x)
    x4 = self.relu(x)
    x = self.pool4(x)

    x = self.conv9(x)
    x = self.relu(x)
    x = self.conv10(x)
    x = self.relu(x)
    x = self.Upsample(x)

    x = torch.cat((x, x4), dim=1)

    x = self.conv11(x)
    x = self.relu(x)
    x = self.conv12(x)
    x = self.relu(x)
    x = self.conv13(x)
    x = self.relu(x)
    x = self.Upsample(x)

    x = torch.cat((x, x3), dim=1)
                  
    x = self.conv14(x)
    x = self.relu(x)
    x = self.conv15(x)
    x = self.relu(x)
    x = self.Upsample(x)

    x = torch.cat((x, x2), dim=1)
      
    x = self.conv16(x)
    x = self.relu(x)
    x = self.conv17(x)

    x = self.relu(x)
    x = self.conv18(x)
    x = self.relu(x)
    x = self.Upsample(x)

    x = torch.cat((x, x1), dim=1)
        
    x = self.conv19(x)
    x = self.relu(x)
    x = self.conv20(x)
    x = self.relu(x)
    x = self.conv21(x)
    x = self.relu(x)
    #x = self.sigmoid(x)

    return x

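I wonder whether some of my layers are wrong. I tried to match the diagram exactly, but maybe I added extra layers or got the channel counts in the wrong order. I also made sure to include the concatenations (the 4 gray arrows in the diagram).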
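One way to find the wrong layer without counting channels by hand is to record the input shape of every layer as the forward pass runs; the last entry recorded before the exception is the layer that failed. A minimal sketch using register_forward_pre_hook; trace_input_shapes and the small toy model are hypothetical names, chosen only to demonstrate the idea:

```python
import torch
import torch.nn as nn


def trace_input_shapes(model: nn.Module, x: torch.Tensor):
    """Run model(x), recording (layer name, input shape) for every leaf-module call.

    Returns the list of records and the error message (None if the pass succeeded).
    """
    records = []
    handles = []

    def make_hook(name):
        def hook(module, args):
            # args is the tuple of positional inputs; these layers take a single tensor.
            records.append((name, tuple(args[0].shape)))
        return hook

    for name, module in model.named_modules():
        if next(module.children(), None) is None:  # leaf modules only
            handles.append(module.register_forward_pre_hook(make_hook(name)))

    error = None
    try:
        with torch.no_grad():
            model(x)
    except RuntimeError as err:
        error = str(err)
    finally:
        for handle in handles:  # always detach the hooks again
            handle.remove()
    return records, error


# Toy model with a deliberate bug: layer "2" expects 16 channels but receives 8.
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 4, kernel_size=3, padding=1),
)
records, error = trace_input_shapes(toy, torch.zeros(1, 3, 8, 8))
for name, shape in records:
    print(name, shape)
print("failed at layer", records[-1][0], "->", error)
```

Applied to the question's model with an input such as torch.zeros(16, 1, 128, 128) (the 128x128 size is an assumption inferred from the 16x16 maps in the error after four poolings and one upsampling), the last record should be conv11 with input shape (16, 1536, 16, 16). torch.cat is not a module, so it does not appear in the trace; its effect shows up in the next layer's input.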

deep-learning conv-neural-network unet-neural-network
1 answer

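You computed the number of input channels for the conv11 layer incorrectly. After the first upsampling operation, you concatenate the upsampled tensor x, which has 1024 channels, with x4, which has 512 channels. Do the math: 1024 + 512 = 1536. So you should change conv11 to: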

      self.conv11 = torch.nn.Conv2d(1536, 512, kernel_size=3, padding='same')
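The arithmetic can be checked directly on dummy tensors: torch.cat(..., dim=1) joins along the channel axis, so channel counts add while the batch and spatial sizes must already match. A small sketch with zero tensors standing in for the real activations (batch size 1 instead of 16 to keep it cheap):

```python
import torch

up = torch.zeros(1, 1024, 16, 16)  # bottleneck output after the first Upsample
x4 = torch.zeros(1, 512, 16, 16)   # skip connection saved from the 4th encoder level

merged = torch.cat((up, x4), dim=1)  # dim=1 is the channel axis in NCHW layout
print(merged.shape)  # torch.Size([1, 1536, 16, 16])

conv11 = torch.nn.Conv2d(1536, 512, kernel_size=3, padding="same")
out = conv11(merged)
print(out.shape)  # torch.Size([1, 512, 16, 16])
```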

You also need to change these layers:

self.conv12 = torch.nn.Conv2d(512, 512, kernel_size=3, padding='same')  # conv11 now outputs 512 channels
self.conv14 = torch.nn.Conv2d(768, 256, kernel_size=3, padding='same')  # 512 (upsampled) + 256 (x3)
self.conv16 = torch.nn.Conv2d(384, 128, kernel_size=3, padding='same')  # 256 (upsampled) + 128 (x2)
self.conv19 = torch.nn.Conv2d(192, 64, kernel_size=3, padding='same')   # 128 (upsampled) + 64 (x1)
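Putting these corrections together, a compact version of the model could look like the sketch below. It keeps the question's layer sequence (three convolutions at the deepest and second-shallowest decoder levels, nearest-neighbour Upsample) but groups the layers with a hypothetical conv_relu helper. Two changes go beyond the answer and are assumptions: each pooling layer is applied to the ReLU output that is saved as the skip connection (the question pools the pre-activation x after conv4, conv6 and conv8), and the final ReLU is dropped so the 1x1 convolution returns raw scores; add a ReLU or sigmoid back if your loss expects it. Input height and width must be divisible by 16 because of the four 2x poolings.

```python
import torch
import torch.nn as nn


def conv_relu(in_ch: int, out_ch: int) -> nn.Sequential:
    """3x3 'same' convolution followed by ReLU (hypothetical helper)."""
    return nn.Sequential(nn.Conv2d(in_ch, out_ch, kernel_size=3, padding="same"), nn.ReLU())


class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder (conv1-conv8) and bottleneck (conv9-conv10).
        self.enc1 = nn.Sequential(conv_relu(1, 64), conv_relu(64, 64))
        self.enc2 = nn.Sequential(conv_relu(64, 128), conv_relu(128, 128))
        self.enc3 = nn.Sequential(conv_relu(128, 256), conv_relu(256, 256))
        self.enc4 = nn.Sequential(conv_relu(256, 512), conv_relu(512, 512))
        self.bottleneck = nn.Sequential(conv_relu(512, 1024), conv_relu(1024, 1024))
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.up = nn.Upsample(scale_factor=2)

        # Decoder: the first conv of each level sees upsampled + skip channels.
        self.dec4 = nn.Sequential(  # conv11-conv13
            conv_relu(1024 + 512, 512), conv_relu(512, 512), conv_relu(512, 512)
        )
        self.dec3 = nn.Sequential(conv_relu(512 + 256, 256), conv_relu(256, 256))  # conv14-conv15
        self.dec2 = nn.Sequential(  # conv16-conv18
            conv_relu(256 + 128, 128), conv_relu(128, 128), conv_relu(128, 128)
        )
        self.dec1 = nn.Sequential(conv_relu(128 + 64, 64), conv_relu(64, 64))  # conv19-conv20
        self.head = nn.Conv2d(64, 1, kernel_size=1)  # conv21, no activation

    def forward(self, x):
        x1 = self.enc1(x)                   # 64 channels, full resolution
        x2 = self.enc2(self.pool(x1))       # 128 channels, 1/2 resolution
        x3 = self.enc3(self.pool(x2))       # 256 channels, 1/4 resolution
        x4 = self.enc4(self.pool(x3))       # 512 channels, 1/8 resolution
        x = self.bottleneck(self.pool(x4))  # 1024 channels, 1/16 resolution

        x = self.dec4(torch.cat((self.up(x), x4), dim=1))  # 1024 + 512 = 1536 in
        x = self.dec3(torch.cat((self.up(x), x3), dim=1))  # 512 + 256 = 768 in
        x = self.dec2(torch.cat((self.up(x), x2), dim=1))  # 256 + 128 = 384 in
        x = self.dec1(torch.cat((self.up(x), x1), dim=1))  # 128 + 64 = 192 in
        return self.head(x)


model = UNet()
with torch.no_grad():
    y = model(torch.zeros(2, 1, 64, 64))
print(y.shape)  # torch.Size([2, 1, 64, 64])
```

Writing each decoder in_channels as "upsampled + skip" keeps the concatenation arithmetic visible in the code, which is exactly where the original mismatch came from.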
© www.soinside.com 2019 - 2024. All rights reserved.