这是我第一次在 PyTorch 中编写/实验条件 GAN 实现。我使用了大量的在线资源来编写一个常规的 GAN,效果非常好。我进行了修改以制作条件 GAN,其关键代码如下:
import torch.nn as nn
class discriminator(nn.Module):
def __init__(self):
super(discriminator, self).__init__()
self.main = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(32),
nn.LeakyReLU(0.1, inplace=True),
nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.1, inplace=True),
nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.1, inplace=True),
nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.1, inplace=True),
nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.1, inplace=True),
nn.Conv2d(512, 1024, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(1024),
nn.LeakyReLU(0.1, inplace=True),
nn.Conv2d(1024, 1, kernel_size=4, stride=1, padding=0, bias=False),
nn.Flatten(),
nn.Sigmoid()
)
def forward(self, x, labels):
x = torch.cat((x, labels), dim=1)
return self.main(x)
discriminator = discriminator()
discriminator = to_device(discriminator,device)
class generator(nn.Module):
def __init__(self, latent_dim):
super(generator, self).__init__()
self.main = nn.Sequential(
nn.ConvTranspose2d(latent_dim, 1024, kernel_size=4, stride=1, padding=0, bias=False),
nn.BatchNorm2d(1024),
nn.LeakyReLU(0.1, inplace=True),
nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.1, inplace=True),
nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.1, inplace=True),
nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.1, inplace=True),
nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.1, inplace=True),
nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(32),
nn.LeakyReLU(0.1, inplace=True),
nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1, bias=False),
nn.Tanh()
)
def forward(self, z, labels):
x = torch.cat((z, labels), dim=1)
return self.main(x)
generator = generator(latent_sz)
generator = to_device(generator,device)
def train_discriminator(real_images, real_labels, opt_d):
opt_d.zero_grad()
real_preds = discriminator(real_images, real_labels)
real_targets = torch.ones(real_images.size(0), 1, device=device)
real_loss = F.binary_cross_entropy(real_preds, real_targets)
real_score = torch.mean(real_preds).item()
latent = torch.randn(batch_size, latent_sz, 1, 1, device=device)
fake_labels = torch.randint(0, num_classes, (batch_size,), device=device)
fake_images = generator(latent, fake_labels)
fake_preds = discriminator(fake_images, fake_labels)
fake_targets = torch.zeros(fake_images.size(0), 1, device=device)
fake_loss = F.binary_cross_entropy(fake_preds, fake_targets)
fake_score = torch.mean(fake_preds).item()
loss = fake_loss + real_loss
loss.backward()
opt_d.step()
return loss.item(), real_score, fake_score
def train_generator(opt_g):
opt_g.zero_grad()
latent = torch.randn(batch_size, latent_sz, 1, 1, device=device)
fake_labels = torch.randint(0, num_classes, (batch_size,), device=device)
fake_images = generator(latent, fake_labels)
targets = torch.ones(batch_size, 1, device=device)
score = discriminator(fake_images, fake_labels)
loss = F.binary_cross_entropy(score, targets)
loss.backward()
opt_g.step()
return loss.item()
def fit(epochs, lr, start_idx=1):
loss_d = []
loss_g = []
real_scores = []
fake_scores = []
opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
opt_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
for epoch in range(epochs):
for img, labels in tqdm(train_loader):
img = img.to(device)
labels = labels.to(device)
loss, real_score, fake_score = train_discriminator(img, labels, opt_d)
lossg = train_generator(opt_g)
loss_d.append(loss)
loss_g.append(lossg)
real_scores.append(real_score)
fake_scores.append(fake_score)
print("Epoch [{}/{}], loss_g: {:.4f}, loss_d: {:.4f}, real_score: {:.4f}, fake_score: {:.4f}, memory_usage: {:.4f}".format(
epoch + 1, epochs, loss, lossg, real_score, fake_score, psutil.virtual_memory()[2]))
save_samples(epoch + start_idx, fixed_latent, fixed_labels, show=False)
return loss_g, loss_d, real_scores, fake_scores
lr = 5e-4
epochs = 20
history = fit(epochs,lr)
我在运行 fit 函数时遇到错误,并且对如何解决它感到困惑。错误如下:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
/scratch/ipykernel_151601/1116731599.py in <module>
1 lr = 5e-4
2 epochs = 20
----> 3 history = fit(epochs,lr)
/scratch/ipykernel_151601/2486237605.py in fit(epochs, lr, start_idx)
13 labels = labels.to(device)
14
---> 15 loss, real_score, fake_score = train_discriminator(img, labels, opt_d)
16 lossg = train_generator(opt_g)
17
/scratch/ipykernel_151601/2842949976.py in train_discriminator(real_images, real_labels, opt_d)
2 opt_d.zero_grad()
3
----> 4 real_preds = discriminator(real_images, real_labels)
5 real_targets = torch.ones(real_images.size(0), 1, device=device)
6
~/anaconda3/envs/Ashank/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
/scratch/ipykernel_151601/3323902491.py in forward(self, x, labels)
34
35 def forward(self, x, labels):
---> 36 x = torch.cat((x, labels), dim=1)
37 return self.main(x)
38
RuntimeError: Tensors must have same number of dimensions: got 4 and 1
我不知道从这里到哪里去解决这个问题。任何帮助将不胜感激 - 如果还需要任何其他代码,请告诉我。预先感谢您:)。
(P.S.如果有人有用于 256x256 图像生成的 PyTorch 条件 GAN 资源,我也很乐意将其视为参考)
为了连接图像(x)和标签(labels),它们必须具有相同的形状。目前,图像有 4 个维度,而标签只有 1 个维度,这使得串联变得不可能。但是,这个问题有一个解决方案:您可以使用标签的嵌入层。该嵌入层将转换标签,使它们具有与图像相同的形状,从而实现成功的串联。
要实现此方法,您可以参考在线教程:https://learnopencv.com/conditional-gan-cgan-in-pytorch-and-tensorflow/。具体来说,请关注标题为“条件判别器实施”的部分。本教程将指导您完成使用嵌入层来协调标签和图像的形状,从而有效促进它们串联的过程。