与输入相比,tf.nn.static_rnn提供不同的输出大小

问题描述 投票:0回答:1

我正在尝试使用tensorflow实现LSTM网络。

我有一个LSTMcell,单位数= 4。

cell = tf.nn.rnn_cell.BasicLSTMCell(4, state_is_tuple=True)

我的输入是一个2 X 44(data_hold)矩阵,我将其拆分为4。

inputs_series = tf.split(data_hold, 4, axis=1)

那么,每个单位的大小必须是2×11对吗?

当我检查我的输入系列时,它是一个形状2 X 11的张量列表,如预期的那样。

[<tf.Tensor 'split:0' shape=(2, 11) dtype=float32>, <tf.Tensor 'split:1' shape=(2, 11) dtype=float32>, <tf.Tensor 'split:2' shape=(2, 11) dtype=float32>, <tf.Tensor 'split:3' shape=(2, 11) dtype=float32>]

但是当我检查state_series时,它是形状2 X 4的张量列表。

# initial_state is initial_state = tf.nn.rnn_cell.LSTMStateTuple(state_data_hold, hidden_data_hold)
state_series, current_step = tf.nn.static_rnn(cell=cell, inputs=inputs_series, initial_state=initial_state)


 # state_series
[<tf.Tensor 'rnn/rnn/basic_lstm_cell/mul_2:0' shape=(2, 4) dtype=float32>, <tf.Tensor 'rnn/rnn/basic_lstm_cell/mul_5:0' shape=(2, 4) dtype=float32>, <tf.Tensor 'rnn/rnn/basic_lstm_cell/mul_8:0' shape=(2, 4) dtype=float32>, <tf.Tensor 'rnn/rnn/basic_lstm_cell/mul_11:0' shape=(2, 4) dtype=float32>]

我的问题是,状态系列和输入系列的形状是否应该相同?

根据static_rnn的文档返回:一对(输出,状态)其中:

输出是输出的长度T列表(每个输入一个),或这些元素的嵌套元组。国家是最终状态

当我打印current_step时,它会返回当前和隐藏单元的元组,但哪个单元是特定的? (有4个吧?)

任何猜测?

python tensorflow lstm recurrent-neural-network
1个回答
1
投票

不应该是状态系列的形状和输入系列相同

我不明白为什么你会这么想。状态形状取决于细胞,而不取决于其输入。所有RNN样细胞的形状由其state_size属性决定,而对于BasicLSTMCell,它是形状[num_units]的两个张量的元组。

© www.soinside.com 2019 - 2024. All rights reserved.