我用注意力机制(attention)和集束搜索(beam search)实现了一个 seq2seq 模型,其编码器为双向多层 LSTM(bidirectional multi-layer LSTM)。(为保持简洁,这里只贴出必要的代码。)
# Helper: build one LSTM layer with dropout applied to its inputs.
def make_lstm(rnn_size, keep_prob):
    """Return a single LSTM cell wrapped in input dropout.

    Args:
        rnn_size: number of units in the LSTM cell.
        keep_prob: keep probability handed to ``DropoutWrapper`` as
            ``input_keep_prob`` (only the cell inputs are dropped out).

    Returns:
        A ``tf.nn.rnn_cell.DropoutWrapper`` around a ``tf.nn.rnn_cell.LSTMCell``
        whose weights are initialised uniformly between -0.1 and 0.1 (seed 2).
    """
    uniform_init = tf.random_uniform_initializer(-0.1, 0.1, seed=2)
    base_cell = tf.nn.rnn_cell.LSTMCell(rnn_size, initializer=uniform_init)
    return tf.nn.rnn_cell.DropoutWrapper(base_cell, input_keep_prob=keep_prob)
# Helper: wrap a decoder cell with Bahdanau attention over the encoder output.
def decoder_cell(dec_cell, rnn_size, enc_output, lengths):
    """Wrap ``dec_cell`` in an ``AttentionWrapper`` using normalised Bahdanau attention.

    Args:
        dec_cell: the RNN cell to wrap (in this file, a ``MultiRNNCell``).
        rnn_size: used both as the attention ``num_units`` and as the
            ``attention_layer_size`` of the wrapper.
        enc_output: attention memory (the encoder outputs). For beam-search
            inference it must already be tiled with ``tile_batch``.
        lengths: per-example memory lengths passed as
            ``memory_sequence_length``; must be tiled consistently with
            ``enc_output``.

    Returns:
        A ``tf.contrib.seq2seq.AttentionWrapper`` around ``dec_cell``.
    """
    bahdanau = tf.contrib.seq2seq.BahdanauAttention(
        num_units=rnn_size,
        memory=enc_output,
        memory_sequence_length=lengths,
        normalize=True,
        name='BahdanauAttention',
    )
    wrapped = tf.contrib.seq2seq.AttentionWrapper(
        cell=dec_cell,
        attention_mechanism=bahdanau,
        attention_layer_size=rnn_size,
    )
    return wrapped
# --- Encoder: stacked bidirectional LSTM ----------------------------------
# Forward and backward directions each get their own stack of n_layers cells.
cell_fw = tf.nn.rnn_cell.MultiRNNCell(
    [make_lstm(rnn_size, keep_prob) for _ in range(n_layers)])
cell_bw = tf.nn.rnn_cell.MultiRNNCell(
    [make_lstm(rnn_size, keep_prob) for _ in range(n_layers)])

(enc_out_fw, enc_out_bw), enc_state = tf.nn.bidirectional_dynamic_rnn(
    cell_fw,
    cell_bw,
    rnn_inputs,
    sequence_length=sequence_length,
    dtype=tf.float32,
)
# Join both directions on the last (feature) axis, so the attention memory
# carries 2 * rnn_size features per time step.
enc_output = tf.concat([enc_out_fw, enc_out_bw], -1)

beam_width = 10

# --- Decoder cell used for training ---------------------------------------
dec_cell = tf.nn.rnn_cell.MultiRNNCell(
    [make_lstm(rnn_size, keep_prob) for _ in range(num_layers)])
# Projection from decoder outputs to vocab_size logits.
output_layer = Dense(
    vocab_size,
    kernel_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.1))
# NOTE(review): the encoder is masked with `sequence_length` while attention
# is masked with `text_length`; presumably both hold the source lengths —
# confirm.
dec_cell = decoder_cell(dec_cell, rnn_size, enc_output, text_length)
# --- Training decoder (teacher forcing) -----------------------------------
with tf.variable_scope("decode"):
    # dec_embed_input comes from another function (presumably the embedded
    # target sequence) and is not relevant in this context.
    helper = tf.contrib.seq2seq.TrainingHelper(
        inputs=dec_embed_input,
        sequence_length=summary_length,
        time_major=False,
    )
    decoder = tf.contrib.seq2seq.BasicDecoder(
        cell=dec_cell,
        helper=helper,
        initial_state=dec_cell.zero_state(batch_size, tf.float32),
        output_layer=output_layer,
    )
    # NOTE: dynamic_decode returns the tuple
    # (final_outputs, final_state, final_sequence_lengths), so `logits` holds
    # that whole tuple; the logits tensor itself is logits[0].rnn_output.
    logits = tf.contrib.seq2seq.dynamic_decode(
        decoder=decoder,
        output_time_major=False,
        impute_finished=True,
        maximum_iterations=max_summary_length,
    )
# --- Inference decoder (beam search) --------------------------------------
# Beam search runs batch_size * beam_width hypotheses in parallel, so every
# batch-major tensor the attention decoder sees must be tiled beam_width times.
enc_output = tf.contrib.seq2seq.tile_batch(enc_output, multiplier=beam_width)
# NOTE(review): the tiled enc_state is not used in this block; the decoder
# starts from zero_state, the same as the training decoder.
enc_state = tf.contrib.seq2seq.tile_batch(enc_state, multiplier=beam_width)
text_length = tf.contrib.seq2seq.tile_batch(text_length, multiplier=beam_width)

# A fresh cell stack + attention wrapper over the *tiled* memory. Its variables
# are picked up from the training decoder through the reused "decode" scope.
# NOTE(review): keep_prob should presumably be fed as 1.0 at inference — confirm.
dec_cell = tf.nn.rnn_cell.MultiRNNCell(
    [make_lstm(rnn_size, keep_prob) for _ in range(num_layers)])
dec_cell = decoder_cell(dec_cell, rnn_size, enc_output, text_length)

# One <GO> token per example: shape [batch_size] (not tiled per beam).
start_tokens = tf.tile(tf.constant([word2ind['<GO>']], dtype=tf.int32),
                       [batch_size], name='start_tokens')

with tf.variable_scope("decode", reuse=True):
    decoder = tf.contrib.seq2seq.BeamSearchDecoder(
        cell=dec_cell,
        embedding=embeddings,
        start_tokens=start_tokens,
        end_token=end_token,
        initial_state=dec_cell.zero_state(batch_size=batch_size * beam_width,
                                          dtype=tf.float32),
        beam_width=beam_width,
        output_layer=output_layer,
        length_penalty_weight=0.0,
    )
    # impute_finished MUST be False with BeamSearchDecoder. Under beam search
    # the `finished` mask is [batch_size, beam_width] (rank 2), while the
    # state tensors inside the loop are [batch_size, beam_width, rnn_size]
    # (rank 3). Imputing finished entries selects between states with
    # tf.where, which needs equal rank. That produced:
    #   "Shapes must be equal rank, but are 3 and 2 ... (op: 'Select')".
    logits = tf.contrib.seq2seq.dynamic_decode(
        decoder=decoder,
        output_time_major=False,
        impute_finished=False,
        maximum_iterations=max_summary_length,
    )
在这一行:
# Excerpt repeated from the beam-search inference code above: the line the
# asker reports the error on. The failing op ('decode_1/decoder/while/Select_4')
# lives inside the decoding while-loop, which dynamic_decode builds. So the
# fix belongs in the dynamic_decode call (impute_finished=False), not in this
# constructor.
decoder = tf.contrib.seq2seq.BeamSearchDecoder( cell=dec_cell,
embedding=embeddings,
start_tokens=start_tokens,
end_token=end_token,
initial_state=dec_cell.zero_state(batch_size = batch_size*beam_width , dtype = tf.float32),
beam_width=beam_width,
output_layer=output_layer,
length_penalty_weight=0.0)
我收到以下错误:
ValueError: Shapes must be equal rank, but are 3 and 2 for 'decode_1/decoder/while/Select_4' (op: 'Select') with input shapes: [64,10], [64,10,256], [64,10,256].
有没有人遇到过同样的问题,或者有这方面的经验?非常感谢任何建议。
环境:TensorFlow 1.6.0,batch_size = 64,rnn_size = 256,beam_width = 10(因此报错中的 64 即 batch_size,10 即 beam_width,256 即 rnn_size)。
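确保向 `dynamic_decode()` 传入 `impute_finished=False`。使用 `BeamSearchDecoder` 时,`finished` 张量的形状是 `[batch_size, beam_width]`(2 阶),而解码循环中的状态张量是 `[batch_size, beam_width, rnn_size]`(3 阶)。`impute_finished=True` 会用 `tf.where`(即报错中的 `Select` 操作)在新旧状态之间选择,而它要求两者阶数相同,于是就出现了上面的错误。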
另外,我认为你还需要用编码器的最终状态来初始化解码器(即令解码器的 init_state = encoder_state),例如:
# Alternative inference setup: start the beam-search decoder from the final
# encoder state instead of from zeros.
enc_output = tf.contrib.seq2seq.tile_batch(enc_output, multiplier=beam_width)

# The encoder is bidirectional, so a decoder with num_layers layers is fed the
# forward and backward states of num_layers // 2 encoder layers, interleaved as
# (fw_0, bw_0, fw_1, bw_1, ...).
# NOTE(review): assumes num_layers is even and equals 2 * n_layers (the
# encoder depth) — confirm.
num_bi_layes = num_layers // 2
# enc_state is (fw_states, bw_states), and each side is a MultiRNNCell state:
# a tuple with one LSTMStateTuple per encoder layer. Always interleave. A
# `num_bi_layes == 1` shortcut of `encoder_state = enc_state` would keep the
# extra nesting level, and would not match the decoder's MultiRNNCell state.
encoder_state = []
for layer_id in range(num_bi_layes):
    encoder_state.append(enc_state[0][layer_id])  # forward
    encoder_state.append(enc_state[1][layer_id])  # backward
encoder_state = tuple(encoder_state)
encoder_state = tf.contrib.seq2seq.tile_batch(encoder_state, multiplier=beam_width)
text_length = tf.contrib.seq2seq.tile_batch(text_length, multiplier=beam_width)

dec_cell = tf.nn.rnn_cell.MultiRNNCell(
    [make_lstm(rnn_size, keep_prob) for _ in range(num_layers)])
dec_cell = decoder_cell(dec_cell, rnn_size, enc_output, text_length)
start_tokens = tf.tile(tf.constant([word2ind['<GO>']], dtype=tf.int32),
                       [batch_size], name='start_tokens')

with tf.variable_scope("decode", reuse=True):
    # dec_cell is an AttentionWrapper. Its state is an AttentionWrapperState
    # (cell state plus attention bookkeeping), so the bare encoder state cannot
    # be passed as initial_state directly. Instead, build the wrapper's zero
    # state for the tiled batch and replace only its inner cell_state.
    initial_state = dec_cell.zero_state(
        batch_size=batch_size * beam_width, dtype=tf.float32
    ).clone(cell_state=encoder_state)
    decoder = tf.contrib.seq2seq.BeamSearchDecoder(
        cell=dec_cell,
        embedding=embeddings,
        start_tokens=start_tokens,
        end_token=end_token,
        initial_state=initial_state,
        beam_width=beam_width,
        output_layer=output_layer,
        length_penalty_weight=0.0,
    )
# NOTE: run this decoder with dynamic_decode(..., impute_finished=False);
# BeamSearchDecoder fails with impute_finished=True (rank-2 `finished` mask vs
# rank-3 beam states).
# NOTE(review): the training decoder starts from zero_state. For inference
# to match training, initialise the training decoder from the (untiled)
# encoder state in the same way.