条件 GAN 拟合运行时错误:张量必须具有相同的维数:得到 4 和 1

问题描述 投票:0回答:1

这是我第一次在 PyTorch 中编写/实验条件 GAN 实现。我使用了大量的在线资源来编写一个常规的 GAN,效果非常好。我进行了修改以制作条件 GAN,其关键代码如下:

import torch.nn as nn
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.1, inplace=True),

            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.1, inplace=True),

            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1, inplace=True),

            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.1, inplace=True),

            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.1, inplace=True),

            nn.Conv2d(512, 1024, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.1, inplace=True),

            nn.Conv2d(1024, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Flatten(),
            nn.Sigmoid()
        )

    def forward(self, x, labels):
        x = torch.cat((x, labels), dim=1)
        return self.main(x)

discriminator = discriminator()
discriminator = to_device(discriminator,device)

class generator(nn.Module):
    def __init__(self, latent_dim):
        super(generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 1024, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.1, inplace=True),

            nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.1, inplace=True),

            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.1, inplace=True),

            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1, inplace=True),

            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.1, inplace=True),

            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.1, inplace=True),

            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh()
        )

    def forward(self, z, labels):
        x = torch.cat((z, labels), dim=1)
        return self.main(x)
   
generator = generator(latent_sz)
generator = to_device(generator,device)

def train_discriminator(real_images, real_labels, opt_d):
    opt_d.zero_grad()

    real_preds = discriminator(real_images, real_labels)
    real_targets = torch.ones(real_images.size(0), 1, device=device)

    real_loss = F.binary_cross_entropy(real_preds, real_targets)
    real_score = torch.mean(real_preds).item()

    latent = torch.randn(batch_size, latent_sz, 1, 1, device=device)
    fake_labels = torch.randint(0, num_classes, (batch_size,), device=device)
    fake_images = generator(latent, fake_labels)

    fake_preds = discriminator(fake_images, fake_labels)
    fake_targets = torch.zeros(fake_images.size(0), 1, device=device)
    fake_loss = F.binary_cross_entropy(fake_preds, fake_targets)
    fake_score = torch.mean(fake_preds).item()

    loss = fake_loss + real_loss
    loss.backward()
    opt_d.step()

    return loss.item(), real_score, fake_score

def train_generator(opt_g):
    opt_g.zero_grad()

    latent = torch.randn(batch_size, latent_sz, 1, 1, device=device)
    fake_labels = torch.randint(0, num_classes, (batch_size,), device=device)
    fake_images = generator(latent, fake_labels)

    targets = torch.ones(batch_size, 1, device=device)
    score = discriminator(fake_images, fake_labels)
    loss = F.binary_cross_entropy(score, targets)

    loss.backward()
    opt_g.step()

    return loss.item()

def fit(epochs, lr, start_idx=1):
    loss_d = []
    loss_g = []
    real_scores = []
    fake_scores = []

    opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))

    for epoch in range(epochs):
        for img, labels in tqdm(train_loader):
            img = img.to(device)
            labels = labels.to(device)

            loss, real_score, fake_score = train_discriminator(img, labels, opt_d)
            lossg = train_generator(opt_g)

        loss_d.append(loss)
        loss_g.append(lossg)
        real_scores.append(real_score)
        fake_scores.append(fake_score)

        print("Epoch [{}/{}], loss_g: {:.4f}, loss_d: {:.4f}, real_score: {:.4f}, fake_score: {:.4f}, memory_usage: {:.4f}".format(
            epoch + 1, epochs, loss, lossg, real_score, fake_score, psutil.virtual_memory()[2]))

        save_samples(epoch + start_idx, fixed_latent, fixed_labels, show=False)

    return loss_g, loss_d, real_scores, fake_scores

lr = 5e-4
epochs = 20
history = fit(epochs,lr)

我在运行 fit 函数时遇到错误,并且对如何解决它感到困惑。错误如下:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/scratch/ipykernel_151601/1116731599.py in <module>
      1 lr = 5e-4
      2 epochs = 20
----> 3 history = fit(epochs,lr)

/scratch/ipykernel_151601/2486237605.py in fit(epochs, lr, start_idx)
     13             labels = labels.to(device)
     14 
---> 15             loss, real_score, fake_score = train_discriminator(img, labels, opt_d)
     16             lossg = train_generator(opt_g)
     17 

/scratch/ipykernel_151601/2842949976.py in train_discriminator(real_images, real_labels, opt_d)
      2     opt_d.zero_grad()
      3 
----> 4     real_preds = discriminator(real_images, real_labels)
      5     real_targets = torch.ones(real_images.size(0), 1, device=device)
      6 

~/anaconda3/envs/Ashank/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

/scratch/ipykernel_151601/3323902491.py in forward(self, x, labels)
     34 
     35     def forward(self, x, labels):
---> 36         x = torch.cat((x, labels), dim=1)
     37         return self.main(x)
     38 

RuntimeError: Tensors must have same number of dimensions: got 4 and 1

我不知道从这里到哪里去解决这个问题。任何帮助将不胜感激 - 如果还需要任何其他代码,请告诉我。预先感谢您:)。

(P.S.如果有人有用于 256x256 图像生成的 PyTorch 条件 GAN 资源,我也很乐意将其视为参考)

python machine-learning pytorch reshape
1个回答
0
投票

为了连接图像(x)和标签(labels),它们必须具有相同的形状。目前,图像有 4 个维度,而标签只有 1 个维度,这使得串联变得不可能。但是,这个问题有一个解决方案:您可以使用标签的嵌入层。该嵌入层将转换标签,使它们具有与图像相同的形状,从而实现成功的串联。

要实现此方法,您可以参考在线教程:https://learnopencv.com/conditional-gan-cgan-in-pytorch-and-tensorflow/。具体来说,请关注标题为“条件判别器实施”的部分。本教程将指导您完成使用嵌入层来协调标签和图像的形状,从而有效促进它们串联的过程。

© www.soinside.com 2019 - 2024. All rights reserved.