tf.gather运行超出范围,同时使用自定义softmax_loss函数,即使它不应该

问题描述 投票:3回答:1

我在tf.contrib.seq2seq.sequence_loss(softmax_loss_function=[...])中使用一个小的自定义函数作为自定义sofmax_loss_function:

    def reduced_softmax_loss(self, labels, logits):
        top_logits, indices = tf.nn.top_k(logits, self.nb_top_classes, sorted=False)
        top_labels = tf.gather(labels, indices)

        return tf.nn.softmax_cross_entropy_with_logits_v2(labels=top_labels,
                                                          logits=top_logits)

但即使标签和logits应该具有相同的维度,执行后它返回和InvalidArgumentError

indices[1500,1] = 2158 is not in [0, 1600)由于我的随机种子而变化。

是否有其他功能,如tf.gather,我可以使用?或者返回值是否为false形状?

如果我通过通常的Tensorflow功能,一切正常。

提前致谢!

python tensorflow softmax seq2seq
1个回答
0
投票

通过查看代码很难分辨出发生了什么,但我不认为你编写的代码符合你的要求。 tf.gather操作需要一个索引输入,其中每个标量值都索引到第一个参数的最外层维度,但是top_k的输出会尝试索引行和列,这会导致超出范围的错误。

© www.soinside.com 2019 - 2024. All rights reserved.