在 Keras 中拟合 VAE 时出现 NAN 损失

问题描述 投票:0回答:1

我正在尝试使用 Keras 在 cifar10 图像上构建变分自动编码器。它在 mnist 数据上完美运行。但是对于 cifar10,当我调用 fit 方法时,我的损失(重建损失和 KL 损失)是 NAN,您可以在此处看到: NAN loss

这是我的带有自定义训练步骤的 VAE:

注:cifar10 images shape = (32, 32, 3), latent dimension = 2

class VAE(Model):
  
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)

        # encoder and decoder
        self.encoder = encoder
        self.decoder = decoder

        # losses metrics
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
      with tf.GradientTape() as tape:
        # see 4. Encoder
        z_mu, z_sigma, z = self.encoder(data)
        z_decoded = self.decoder(z)

        # compute the losses
        reconstruction_loss = tf.reduce_mean(
                  tf.reduce_sum(
                      keras.losses.binary_crossentropy(data, z_decoded), axis=(1, 2)
                  )
              )
        kl_loss = -(1 + z_sigma - z_mu**2 - tf.exp(z_sigma)) / 2
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
        total_loss = reconstruction_loss + kl_loss

        # gradients
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        # update losses
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)

        #  return the final losses
        return {
              "loss": self.total_loss_tracker.result(),
              "reconstruction_loss": self.reconstruction_loss_tracker.result(),
              "kl_loss": self.kl_loss_tracker.result(),
          }

我的编码器: encoder graph

我的解码器: decoder graph

有人有想法吗?

tensorflow keras autoencoder loss
1个回答
0
投票

我认为问题出在训练阶跃函数上,这是因为使用了错误的损失函数。将其更改为 categorical_crossentropy 而不是 binary_crossentropy 将起作用。

© www.soinside.com 2019 - 2024. All rights reserved.