张量,如何收集索引列表的值?

问题描述 投票:0回答:1
t2 = tf.constant([[0, 11, 2, 3, 4],
                  [5, 61, 7, 8, 9],
                  [10, 11, 12, 13, 14],
                  [15, 16, 17, 18, 19]])
valid_mask = t2 <= 10
validIndex  =  tf.where(valid_mask)
print('validIndex',validIndex) # Expectation = Reality

print()

print('Final Output',tf.gather(t2,indices=validIndex)) # Hmm.. What ?

我的最终输出是 tf.Tensor( [[[ 0 11 2 3 4] [ 0 11 2 3 4]]

[[ 0 11 2 3 4] [10 11 12 13 14]]......

[[10 11 12 13 14] [ 0 11 2 3 4]]], shape=(9, 2, 5), dtype=int32)

期待 [0,2,3,4,5,7,8,9]

  1. 请大家帮忙调试改正
  2. 请解释发生了什么
python tensorflow tensor
1个回答
0
投票

使用

tf.gather_nd
tf.boolean_mask

import tensorflow as tf
t2 = tf.constant([[0, 11, 2, 3, 4],
                  [5, 61, 7, 8, 9],
                  [10, 11, 12, 13, 14],
                  [15, 16, 17, 18, 19]])
valid_mask = t2 <= 10
validIndex  =  tf.where(valid_mask)

print(tf.gather_nd(t2, indices=validIndex))
print(tf.boolean_mask(t2, valid_mask))
tf.Tensor([ 0  2  3  4  5  7  8  9 10], shape=(9,), dtype=int32)
tf.Tensor([ 0  2  3  4  5  7  8  9 10], shape=(9,), dtype=int32)

顺便说一句,根据您的情况,您的预期输出应该包括数字 10。

© www.soinside.com 2019 - 2024. All rights reserved.