I implemented the original U-Net model with the help of the documentation and some articles, and tuned its hyperparameters to get the best results (dice loss and accuracy). Now I am trying to implement a GAN, but I am getting very strange results, e.g. a very high generator loss and a very low discriminator loss. I think I made a mistake when implementing the generator, or I implemented the GAN training loop incorrectly. How should I debug and fix the GAN training loop below?
# Define hyperparameters
epochs = 50 # Number of training epochs
batch_size = 32 # Batch size
discriminator_lr = 0.0002 # Learning rate for discriminator
generator_lr = 0.0002 # Learning rate for generator
# Define optimizers
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=discriminator_lr)
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=generator_lr)
# Training loop
for epoch in range(epochs):
    for batch_real_images, batch_masks in dataset:
        # Generate fake masks using the generator
        batch_fake_masks = generator.predict(batch_real_images)
        batch_size = tf.shape(batch_real_images)[0]

        # Combine real and fake masks for the discriminator
        combined_masks = tf.concat([batch_masks, batch_fake_masks], axis=0)

        # Create labels for the discriminator (1 for real, 0 for fake)
        real_labels = tf.ones((batch_size, 1))
        fake_labels = tf.zeros((batch_size, 1))
        labels = tf.concat([real_labels, fake_labels], axis=0)
        # Train the discriminator
        bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
        with tf.GradientTape() as disc_tape:
            discriminator_loss = bce(labels, discriminator(combined_masks))
        discriminator_gradients = disc_tape.gradient(discriminator_loss, discriminator.trainable_variables)
        discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))
        # Train the generator
        with tf.GradientTape() as gen_tape:
            generated_masks = generator(batch_real_images)
            discriminator_output = discriminator(generated_masks)
            generator_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)(tf.ones_like(discriminator_output), discriminator_output)
        generator_gradients = gen_tape.gradient(generator_loss, generator.trainable_variables)
        generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
    print(f"Epoch {epoch + 1}/{epochs}, Discriminator Loss: {discriminator_loss}, Generator Loss: {generator_loss}")
When trying to implement or understand a GAN, I strongly recommend studying the published code of high-quality research papers. Mastering the art of training GANs is far from easy, so it is essential to dig into implementations of smaller GAN models and adapt an existing codebase. A thorough understanding of how the loss functions must be configured is also critical to making this work.
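As a starting point, here is a minimal sketch of a corrected training step. The tiny stand-in models, shapes (8×8 single-channel images and masks), random smoke-test data, and the supervised-loss weight of 100 are all assumptions for illustration; substitute your U-Net generator, your discriminator, and your dataset. The two changes most likely to matter for your symptom (huge generator loss, tiny discriminator loss) are: calling the generator directly instead of `generator.predict()` inside the loop, and adding a supervised segmentation term to the generator loss, since a purely adversarial loss is usually unstable for segmentation.

```python
import tensorflow as tf

# Tiny stand-in models (assumption): replace with your U-Net generator
# and your discriminator. Shapes here are placeholders.
generator = tf.keras.Sequential([
    tf.keras.layers.Input((8, 8, 1)),
    tf.keras.layers.Conv2D(4, 3, padding="same", activation="relu"),
    tf.keras.layers.Conv2D(1, 1, activation="sigmoid"),  # masks in [0, 1]
])
discriminator = tf.keras.Sequential([
    tf.keras.layers.Input((8, 8, 1)),
    tf.keras.layers.Conv2D(4, 3, strides=2, activation="relu"),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(1),  # raw logits: no sigmoid, so from_logits=True below
])

# Build the loss objects once, outside the training loop.
adv_bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
sup_bce = tf.keras.losses.BinaryCrossentropy()  # generator output is already sigmoid
disc_opt = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
gen_opt = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

def train_step(images, real_masks):
    # --- Discriminator step ---
    # Call the generator directly: predict() returns NumPy arrays and is
    # much slower inside a training loop.
    fake_masks = generator(images, training=True)
    with tf.GradientTape() as disc_tape:
        real_logits = discriminator(real_masks, training=True)
        fake_logits = discriminator(fake_masks, training=True)
        disc_loss = (adv_bce(tf.ones_like(real_logits), real_logits)
                     + adv_bce(tf.zeros_like(fake_logits), fake_logits))
    grads = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    disc_opt.apply_gradients(zip(grads, discriminator.trainable_variables))

    # --- Generator step ---
    with tf.GradientTape() as gen_tape:
        fake_masks = generator(images, training=True)
        fake_logits = discriminator(fake_masks, training=False)
        adv_loss = adv_bce(tf.ones_like(fake_logits), fake_logits)
        # Supervised term against the ground-truth masks stabilises a
        # segmentation GAN; the weight 100 is an assumption (pix2pix-style).
        sup_loss = sup_bce(real_masks, fake_masks)
        gen_loss = adv_loss + 100.0 * sup_loss
    grads = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gen_opt.apply_gradients(zip(grads, generator.trainable_variables))
    return disc_loss, gen_loss

# Smoke test on random data (an assumption; use your own dataset).
imgs = tf.random.uniform((4, 8, 8, 1))
masks = tf.cast(tf.random.uniform((4, 8, 8, 1)) > 0.5, tf.float32)
d, g = train_step(imgs, masks)
print(float(d), float(g))
```

Note that with this setup the generator loss is dominated by the supervised term early on and falls as the masks improve, while the adversarial term keeps the discriminator and generator losses in the same order of magnitude rather than diverging.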
Take a look at: