如何在应用 conv2d 和 convtranspose2d 步骤后保持输入和输出形状一致?

问题描述 投票:0回答:0

美好的一天!

我正在使用 Pytorch 来试验图像分割作业。使用 stride=2,我应用 conv2d 进行下采样,然后应用 convtranspose2d 进行上采样。我发现输出形状通常与原始输入形状不一致。对于下面的示例,原始图像形状为 (33,34),返回的形状变为 (48,48)。我认为原因是图像高度或宽度上的像素数量奇数,但如何解决它以在整个过程中保持一致?

致以诚挚的问候

import torch
import torch.nn.functional as F

class FCN(torch.nn.Module):

    class Block(torch.nn.Module):
        def __init__(self, n_input, n_output, kernel_size=3, stride=2):
            super().__init__()
            self.c1 = torch.nn.Conv2d(n_input, n_output, kernel_size=kernel_size, padding=kernel_size // 2,
                                      stride=stride, bias=False)

            self.b1 = torch.nn.BatchNorm2d(n_output)

        def forward(self, x):
            return F.relu(self.b1(self.c1(x)))

    class UpBlock(torch.nn.Module):
        def __init__(self, n_input, n_output, kernel_size=3, stride=2):
            super().__init__()
            self.c1 = torch.nn.ConvTranspose2d(n_input, n_output, kernel_size=kernel_size, padding=kernel_size // 2, stride=stride, output_padding=1)
        def forward(self, x):
            return F.relu(self.c1(x))

    def __init__(self, n_output_channels=5, kernel_size=3):
        super().__init__()

        self.add_module('conv0', self.Block(3, 16, kernel_size, 2))
        self.add_module('conv1', self.Block(16, 32, kernel_size, 2))
        self.add_module('conv2', self.Block(32, 64, kernel_size, 2))
        self.add_module('conv3', self.Block(64, 128, kernel_size, 2))
        self.add_module("upconv0", self.UpBlock(128, 64, kernel_size, 2))
        self.add_module("upconv1", self.UpBlock(64, 32, kernel_size, 2))
        self.add_module("upconv2", self.UpBlock(32, 16, kernel_size, 2))
        self.add_module("upconv3", self.UpBlock(16, 8, kernel_size, 2))

        self.classifier = torch.nn.Conv2d(8, n_output_channels, 1)

    def forward(self, x):
        x = self._modules['conv0'](x)
        print(x.shape)
        x = self._modules['conv1'](x)
        print(x.shape)
        x = self._modules['conv2'](x)
        print(x.shape)
        x = self._modules['conv3'](x)
        print(x.shape)
        x = self._modules['upconv0'](x)
        print(x.shape)
        x = self._modules['upconv1'](x)
        print(x.shape)
        x = self._modules['upconv2'](x)
        print(x.shape)
        x = self._modules['upconv3'](x)
        print(x.shape)
        return self.classifier(x)


FCN()(torch.zeros(1,3,33,34))
machine-learning deep-learning pytorch computer-vision image-segmentation
© www.soinside.com 2019 - 2024. All rights reserved.