将已弃用的dynamic_rnn转换为TensorFlow 2.0

问题描述 投票:0回答:1

我最近正在使用

tf.compat.v1.nn.dynamic_rnn
作为其中一种方法来重新开发一个旧的应用程序。整个代码块是:

all_lstm_outputs, self.state = tf.compat.v1.nn.dynamic_rnn(
        drop_multi_cell, all_inputs, initial_state=tuple(initial_state),
        time_major = True, dtype=tf.float32)

但是,该版本不再适用于2.0,我希望将代码转换为

tf.keras.layers.RNN

对于其他上下文,

drop_multi_cell = tf.keras.layers.StackedRNNCells(drop_lstm_cells)
,但我很难看出代码如何适合最新的模块。

我正在测试:

rnn_layer = tf.keras.layers.RNN([drop_multi_cell],
                                time_major=True)
        all_lstm_outputs, self.state = rnn_layer(all_inputs, initial_state=tuple(initial_state))

但看起来不对。

对于其他情况,

https://www.tensorflow.org/api_docs/python/tf/compat/v1/nn/dynamic_rnn https://www.tensorflow.org/api_docs/python/tf/keras/layers/RNN

python tensorflow tensorflow2.0
1个回答
0
投票

在 TensorFlow 2.x 中,tf.keras.layers.RNN 用于为模型创建循环层。您走在正确的轨道上,但需要进行一些调整才能正确地将 tf.keras.layers.RNN 与 drop_multi_cell 一起使用。

请按照以下步骤解决问题:

首先需要导入Tensorflow。

使用 tf.keras.layers.StackedRNNCells 定义 drop_lstm_cells 并创建 drop_multi_cell。

使用 tf.keras.layers.RNN 创建 rnn_layer,直接传递 drop_multi_cell。

并使用带有输入 all_inputs 和初始状态initial_state 的rnn_layer。

现在,这应该可以正确地将您的代码转换为在 TensorFlow 2.x 中使用 tf.keras.layers.RNN。 rnn_layer 将具有 return_sequences=True 和 return_state=True 以确保您获得输出和最终状态。

time_major=True 参数将输入和输出张量保持为时间主格式,这似乎是原始 tf.compat.v1.nn.dynamic_rnn 调用中的内容。

© www.soinside.com 2019 - 2024. All rights reserved.