我想通过仅选择张量第二个维度中的一个元素,将形状为 (128, 10, 51) 的 TensorFlow 张量更改为形状为 (128, 51) 的张量。我必须在 ndarray 中选择索引,如下所示:
A 是形状为 (128, 10, 51) 的张量
B 是形状为 (128,) 的 ndarray,元素为 0 到 9
我是通过 for 循环完成的,但我想要一个紧凑的代码来在一/两行内完成此操作。
您可以尝试以下行:
result = tf.gather_nd(A, tf.stack([tf.range(tf.shape(A)[0]), B], axis=1))
tf.gather_nd
函数从 A 张量获取存储在 tf.stack([tf.range(tf.shape(A)[0]), B]
中的切片。