tensorflow - ValueError:解码器/while/Merge_12:0 的形状不是循环的不变量

问题描述 投票:0回答:1

我使用 tf.contrib.seq2seq.dynamic_decode 进行解码器训练

prediction, final_decoder_state, _ = dynamic_decode(
    custom_decoder
)

带有自定义解码器

custom_decoder = CustomDecoder(decoder_cell, helper, decoder_init_state)

还有帮手

helper = CustomTrainingHelper(batch_size, targets, stop_targets,
                              num_outs, outputs_per_step, 1.0, False)

dynamic_decoder 引发错误

Traceback (most recent call last):
  File "E:/tasks/text_to_speech/tts/tf_seq2seq.py", line 95, in <module>
    custom_decoder
  File "C:\Users\User\Anaconda3\lib\site-packages\tensorflow\contrib\seq2seq\python\ops\decoder.py", line 304, in dynamic_decode
    swap_memory=swap_memory)
  File "C:\Users\User\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 3224, in while_loop
    result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
  File "C:\Users\User\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2956, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "C:\Users\User\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2930, in _BuildLoop
    next_vars.append(_AddNextAndBackEdge(m, v))
  File "C:\Users\User\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 688, in _AddNextAndBackEdge
    _EnforceShapeInvariant(m, v)
  File "C:\Users\User\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 632, in _EnforceShapeInvariant
    (merge_var.name, m_shape, n_shape))
ValueError: The shape for decoder/while/Merge_12:0 is not an invariant for the loop. It enters the loop with shape (10, 1), but has shape (?, 1) after one iteration. Provide shape invariants using either the `shape_invariants` argument of tf.while_loop or set_shape() on the loop variables.

batch_size 等于 10。据我了解,问题出在 tf.while_loop 和 batch_size 中。有什么方法可以修复这个错误?预先感谢。

tensorflow dynamic while-loop decode
1个回答
1
投票

一般来说,此错误告诉您以下内容。默认情况下,TensorFlow 检查从 while 循环的一次迭代传递到下一次迭代的变量是否不会改变形状。在您的例子中,

decoder/while/Merge_12:0 
张量最初的形状为
(10, 1)
,但在一次迭代后它变成了
(?, 1)
,这意味着张量流无法再推断第一维的大小。

如果您知道第一个维度确实是

10
,您可以使用 Tensor.set_shape 将其告知 TensorFlow。

© www.soinside.com 2019 - 2024. All rights reserved.