调试 GAN 的鉴别器实现

问题描述 投票:0回答:1

我在将鉴别器包含到

SRGAN
的实现中时遇到问题。在 Flickr 数据集上进行训练时,我发现鉴别器无法尽早学习任何内容(
BCELoss
显示值为 100)并且永远无法恢复。我尝试了一下并删除了 sigmoid,希望使用
BCEWithLogits
作为损失。这导致损失一开始变化很大,最后趋于零。

调试鉴别器实现的好方法是什么?我怀疑我在训练中调用鉴别器的方式有问题。

class DiscriminatorConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super(DiscriminatorConvBlock, self).__init__()
        num_groups = 8
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1, bias=False), 
                                   nn.GroupNorm(num_groups, out_channels),
                                   nn.LeakyReLU(0.2, False),
                                 )
    def forward(self, x):
        out = self.conv1(x)
        return out

class Discriminator(nn.Module):
    def __init__(self, low_res_dim):
        super(Discriminator, self).__init__()
        img_d = int(low_res_dim / 4)
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False), 
                                  nn.LeakyReLU(0.2, False),
                                 )
        self.conv2 = DiscriminatorConvBlock(64, 64, 2)
        self.conv3 = DiscriminatorConvBlock(64, 128, 1)
        self.conv4 = DiscriminatorConvBlock(128, 128, 2)
        self.conv5 = DiscriminatorConvBlock(128, 256, 1)
        self.conv6 = DiscriminatorConvBlock(256, 256, 2)
        self.conv7 = DiscriminatorConvBlock(256, 512, 1)
        self.conv8 = DiscriminatorConvBlock(512, 512, 2)

        self.dense1 = nn.Linear(512 * img_d * img_d , 1024)
        self.leakyRelu = nn.LeakyReLU(0.2, False)
        self.dense2 = nn.Linear(1024 , 1)
        self.drop = nn.Dropout(0.3)
        
    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)
        out = self.conv6(out)
        out = self.conv7(out)
        out = self.conv8(out)
        out = out.view(-1, out.size(1) * out.size(2) * out.size(3))
        out = self.leakyRelu(self.dense1(out))
        out = self.dense2(out)
        out = torch.clamp_(out, 0.0, 1.0)
        return out


gen_model = Generator().to(device)
disc_model = Discriminator(low_res).to(device)

# VGG-19 feature extractor for the perceptual (content) loss; weights frozen.
vgg = models.vgg19(pretrained=True).to(device)
feature_nodes = ["features.35"]
feature_extractor = create_feature_extractor(vgg, feature_nodes)
feature_extractor_nodes = feature_nodes
normalizeT = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

for model_parameters in feature_extractor.parameters():
    model_parameters.requires_grad = False
feature_extractor.eval()

gen_optimizer = optim.Adam(gen_model.parameters(), lr=1e-4)
disc_optimizer = optim.Adam(disc_model.parameters(), lr=1e-5)
gen_scheduler = CosineAnnealingWarmRestarts(gen_optimizer,
                                            T_0=8,       # epochs until first restart
                                            T_mult=1,    # restart-period growth factor
                                            eta_min=1e-5)  # minimum learning rate
disc_scheduler = CosineAnnealingWarmRestarts(disc_optimizer,
                                             T_0=8,       # epochs until first restart
                                             T_mult=1,    # restart-period growth factor
                                             eta_min=1e-6)  # minimum learning rate
mse_loss = nn.MSELoss()
vgg_loss = nn.MSELoss()
disc_loss = nn.BCEWithLogitsLoss()
disc_loss_generator = nn.BCEWithLogitsLoss()

for epoch in range(num_epochs):
    for batch_idx, data in enumerate(tqdm.tqdm(dataloader)):
        # BUG FIX: inner loop index renamed — the original reused `i` for
        # both the batch index and the VGG-node loop below, shadowing it.
        input_images, labels = data
        input_images = input_images.to(device)

        lowres_images = transforms.Resize(low_res)(input_images)
        gen_highres_images = gen_model(lowres_images)

        # ---------------- Discriminator update ----------------
        for model_parameters in disc_model.parameters():
            model_parameters.requires_grad = True
        disc_model.zero_grad()

        actual_label = disc_model(input_images)
        # Real images should be classified as 1.
        d2_loss = disc_loss(actual_label,
                            torch.ones_like(actual_label, dtype=torch.float))
        d2_loss.backward()

        # BUG FIX: detach() the generated images so the discriminator's
        # backward pass does not flow into (and retain) the generator graph;
        # this also removes the need for retain_graph=True here.
        generated_label = disc_model(gen_highres_images.detach())
        d1_loss = disc_loss(generated_label,
                            torch.zeros_like(generated_label, dtype=torch.float))
        d1_loss.backward()

        errD = d2_loss + d1_loss
        disc_optimizer.step()

        # ---------------- Generator update ----------------
        gen_model.zero_grad()
        gen_optimizer.zero_grad()

        # Pixel-space MSE term of the perceptual loss.
        mse = mse_loss(normalizeT(gen_highres_images), normalizeT(input_images))

        # VGG feature-space term: generated (sr) vs ground-truth (gt) features.
        # (Names were swapped in the original; MSE is symmetric so behavior
        # is unchanged, but the labels now match their contents.)
        gt_feature = feature_extractor(normalizeT(input_images))
        sr_feature = feature_extractor(normalizeT(gen_highres_images))
        vgg_losses = [vgg_loss(sr_feature[node], gt_feature[node])
                      for node in feature_extractor_nodes]

        # Freeze the discriminator while computing the adversarial term.
        for model_parameters in disc_model.parameters():
            model_parameters.requires_grad = False
        actual_generated_label = disc_model(gen_highres_images)
        # BUG FIX: target shaped like actual_generated_label, not
        # actual_label — the two only coincide when batch sizes match.
        gen_disc_loss = disc_loss_generator(
            actual_generated_label,
            torch.ones_like(actual_generated_label, dtype=torch.float))
        generator_loss = vgg_losses[0] + mse + gen_disc_loss
        generator_loss.backward()
        gen_optimizer.step()
        torch.cuda.empty_cache()

    # BUG FIX: step the LR schedulers once per epoch AFTER the optimizers
    # have stepped (stepping first skips the initial LR and triggers a
    # PyTorch warning).
    gen_scheduler.step()
    disc_scheduler.step()
pytorch generative-adversarial-network
1个回答
0
投票

问题是

out = torch.clamp_(out, 0.0, 1.0)
。这对于你想要模型做什么以及你正在使用的损失来说没有意义。

BCEWithLogitsLoss
会先应用 sigmoid 再计算损失。在您将输出钳制到的范围
[0, 1]
上,sigmoid 函数只会返回
[0.5, 1]
。您实质上是在强制模型对每个示例都以 >=50% 的置信度预测类别
1
。

© www.soinside.com 2019 - 2024. All rights reserved.