GAN奇怪输出ETA

问题描述 投票:0回答:1

这里是GAN的代码。

# Load the dataset
(X_train, _), (_, _) = mnist.load_data()

# Rescale -1 to 1
X_train = X_train / 127.5 - 1.
X_train = np.expand_dims(X_train, axis=3)

设置生成器网络 潜伏暗淡 = 100

generator = Sequential()

generator.add(Dense(256 * 7 * 7, input_dim=latent_dim))
generator.add(Reshape((7, 7, 256)))
generator.add(UpSampling2D())
generator.add(Conv2D(128, kernel_size=3, padding="same"))
generator.add(BatchNormalization(momentum=0.8))
generator.add(LeakyReLU(alpha=0.2))
generator.add(UpSampling2D())
generator.add(Conv2D(64, kernel_size=3, padding="same"))
generator.add(BatchNormalization(momentum=0.8))
generator.add(LeakyReLU(alpha=0.2))
generator.add(Conv2D(1, kernel_size=3, padding="same"))
generator.add(Activation("tanh"))

设置鉴别器网络 鉴别器=顺序()

discriminator.add(Conv2D(32, kernel_size=3, strides=2, input_shape=(28, 28, 1), padding="same"))
discriminator.add(LeakyReLU(alpha=0.2))
discriminator.add(Dropout(0.25))
discriminator.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
discriminator.add(ZeroPadding2D(padding=((0,1),(0,1))))
discriminator.add(BatchNormalization(momentum=0.8))
discriminator.add(LeakyReLU(alpha=0.2))
discriminator.add(Dropout(0.25))
discriminator.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
discriminator.add(BatchNormalization(momentum=0.8))
discriminator.add(LeakyReLU(alpha=0.2))
discriminator.add(Dropout(0.25))
discriminator.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))
discriminator.add(BatchNormalization(momentum=0.8))
discriminator.add(LeakyReLU(alpha=0.2))
discriminator.add(Dropout(0.25))
discriminator.add(Flatten())
discriminator.add(Dense(1, activation='sigmoid'))

定义 GAN 网络

gan_input = Input(shape=(latent_dim,))
generated_image = generator(gan_input)
gan_output = discriminator(generated_image)

gan = Model(inputs=gan_input, outputs=gan_output)

# Compile the discriminator and GAN networks
discriminator.compile(loss='binary_crossentropy',
                      optimizer=Adam(0.0002, 0.5),
                      metrics=['accuracy'])

gan.compile(loss='binary_crossentropy', optimizer='adam')

训练 GAN

epochs = 10000
batch_size = 512
steps_per_epoch = int(X_train.shape[0] / batch_size)
print(steps_per_epoch)

for epoch in range(epochs):
    # Train the discriminator
    for step in range(steps_per_epoch):
        # Geta batch of real images from the training data
        real_images = X_train[np.random.randint(0, X_train.shape[0], batch_size)]

        # Generate a batch of fake images using the generator
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        fake_images = generator.predict(noise)

        # Train the discriminator on the real and fake images
        discriminator_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))
        discriminator_loss_fake = discriminator.train_on_batch(fake_images, np.zeros((batch_size, 1)))
        discriminator_loss = 0.5 * np.add(discriminator_loss_real, discriminator_loss_fake)

        # Train the generator
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        generator_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))

    # Print the progress and save the generated images
    print("Epoch {} Discriminator Loss: {} Generator Loss: {}".format(epoch, discriminator_loss[0], generator_loss))

    if epoch % 100 == 0:
        # Save the generated images
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, latent_dim))
        generated_images = generator.predict(noise)
        generated_images = 0.5 * generated_images + 0.5
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(generated_images[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("gan_mnist_epoch_{}.png".format(epoch))
        plt.close()

这里是输出:

Epoch 21 Discriminator Loss: 19.256301978603005 Generator Loss: 0.010846754536032677
16/16 [==============================] - 0s 22ms/step
16/16 [==============================] - 0s 21ms/step
16/16 [==============================] - 0s 22ms/step
...
16/16 [==============================] - 0s 22ms/step
16/16 [==============================] - 0s 23ms/step
16/16 [==============================] - ETA: 0s

在这样的“ETA:0s”之后,任何其他输出都将停止。代码有效,保存了 100 个时期的图像,但没有输出行。可能这是jupyter notebook的问题。如何解决?

tensorflow keras mnist generative-adversarial-network
1个回答
0
投票

解决方案很简单:缓冲区溢出。我们可以用

替换输出
print("Epoch {} Discriminator Loss: {} Generator Loss: {}".format(epoch, discriminator_loss[0], generator_loss), flush=True)
© www.soinside.com 2019 - 2024. All rights reserved.