在TensorFlow中使用MultiRNNCell的随机初始状态

问题描述 投票:1回答:1

我有一个以这种方式创建的MultiRNN单元

def get_cell(cell_type, num_units, training):
    if cell_type == "RNN":
        cell = tf.contrib.rnn.BasicRNNCell(num_units)
    elif cell_type == "LSTM":
        cell = tf.contrib.rnn.BasicLSTMCell(num_units)
    else:
        cell = tf.contrib.rnn.GRUCell(num_units)

    if training:
        cell = tf.contrib.rnn.DropoutWrapper(cell,
                                input_keep_prob=params["dropout_input_keep_prob"],
                                output_keep_prob=params["dropout_output_keep_prob"],
                                state_keep_prob=params["dropout_state_keep_prob"])

    return cell

final_cell_structure = tf.contrib.rnn.MultiRNNCell([get_cell(cell_type, num_units, (mode == tf.estimator.ModeKeys.TRAIN)) for _ in range(num_layers)])

我试图将其状态初始化为随机值。我试过这样做:

initial_state = state = final_cell_structure.zero_state(batch_size, tf.float32)
if mode == tf.estimator.ModeKeys.PREDICT:
    state = state + tf.random_normal(shape=tf.shape(state), mean=0.0, stddev=0.6)

但我不断收到错误消息

Expected state to be a tuple of length 3, but received: Tensor("Reshape:0", shape=(3, 1, 10), dtype=float32)

当我使用它

output, state = final_cell_structure(inputs, state)

更新我尝试使用

state = [st + tf.random_normal(shape=tf.shape(st), mean=0.0, stddev=0.6) for st in state]

正如Pop所建议的,它适用于基本RNN单元格和GRU单元格,但当我将它与LSTM单元格一起使用时,我会收到以下错误

Tensor objects are not iterable when eager execution is not enabled. To iterate over this tensor use tf.map_fn

求解LSTM单元状态由元组组成,所以我发现这个解决方案有效

state_placeholder = tf.random_normal(shape=(num_layers, 2, batch_size, num_units), mean=0.0, stddev=1.0)
l = tf.unstack(state_placeholder, axis=0)
state = tuple([tf.nn.rnn_cell.LSTMStateTuple(l[idx][0],l[idx][1]) for idx in range(num_layers)])
python tensorflow state rnn
1个回答
0
投票

这个想法是state是一个元组。

所以你需要以这种方式更新它:

state = [st + tf.random_normal(shape=tf.shape(st), mean=0.0, stddev=0.6) for st in state]

它应该工作。

使用您的方法,您创建了单个张量f形状(2,b,k)而不是具有相同大小(b,k)的两个张量的元组

© www.soinside.com 2019 - 2024. All rights reserved.