有没有另一种loss可以替代tensorflow中的seq2seq.sequence_loss

问题描述 投票:0回答:1

我正在运行 CVAE 来生成文本。我使用的是tensorflow > 2.0。

问题是,对于我的损失,我使用 seq2seq.sequence_loss。自从 6 年前发布代码以来,我尝试将 TensorFlow v1 更新为 v2。

下面是我的代码片段:

class CustomVariationalLayer(Layer):
    def __init__(self, **kwargs):
        self.is_placeholder = True
        super(CustomVariationalLayer, self).__init__(**kwargs)
        self.target_weights = tf.constant(np.ones((batch_size, max_len)), tf.float32)

def vae_loss(self, x, x_decoded_mean):
    labels = tf.cast(x, tf.int32)
    xent_loss = K.sum(tf.contrib.seq2seq.sequence_loss(x_decoded_mean, labels, 
                                                 weights=self.target_weights,
                                                 average_across_timesteps=False,
                                                 average_across_batch=False), axis=-1)#,
                                                 #softmax_loss_function=softmax_loss_f), axis=-1)#,
    kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    xent_loss = K.mean(xent_loss)
    kl_loss = K.mean(kl_loss)
    return K.mean(xent_loss + kl_weight * kl_loss)

运行脚本后出现此错误:

文件“/dm4i/work/VAE-testbyme/Textvae_v2.py”,第 193 行,调用中* 损失 = self.vae_loss(x, x_decoded_mean) 文件“/dm4i/work/VAE-testbyme/Textvae_v2.py”,第 179 行,vae_loss * xent_loss = K.sum(tf.contrib.seq2seq.sequence_loss(x_decoded_mean, labels,

AttributeError: module 'tensorflow' has no attribute 'contrib'

有什么提示吗?或者提示?

非常感谢

tensorflow tensorflow2.0 tf.keras tensorflow-serving
1个回答
0
投票

一种可能性是使用

tfa.seq2seq.sequence_loss
,它基本上是
tf.contrib.seq2seq.sequence_loss
到 TensorFlow 2 的端口。您可以通过安装 TensorFlow Addons 包来获取它。但这并不理想,因为 Tensorflow Addons (
tfa
) 也已弃用,并将在 2024 年 5 月后停止接受任何维护。

© www.soinside.com 2019 - 2024. All rights reserved.