仅选择张量第二个维度中的一个元素,将形状 (128, 10, 51) 的 TensorFlow 张量更改为形状 (128, 51) 的张量

问题描述 投票:0回答:1

我想通过仅选择张量第二个维度中的一个元素,将形状为 (128, 10, 51) 的 TensorFlow 张量更改为形状为 (128, 51) 的张量。我必须在 ndarray 中选择索引,如下所示:

A 是形状为 (128, 10, 51) 的张量

B 是形状为 (128,) 的 ndarray,元素为 0 到 9

我是通过 for 循环完成的,但我想要一个紧凑的代码来在一/两行内完成此操作。

tensorflow slice tensor
1个回答
0
投票

您可以尝试以下行:

result = tf.gather_nd(A, tf.stack([tf.range(tf.shape(A)[0]), B], axis=1))

tf.gather_nd
函数从 A 张量获取存储在
tf.stack([tf.range(tf.shape(A)[0]), B]
中的切片。

© www.soinside.com 2019 - 2024. All rights reserved.