两个不同长度的张量之间的相交

问题描述 投票:3回答:1

我有一个张量流的情况。我想找到两个具有不同形状的二维张量的交点。

例:

object_ids_  [[0 0]
              [0 1]
              [1 1]]

object_ids_more_07_  [[0 0]
                      [0 1]
                      [0 2]
                      [1 0]
                      [1 2]]

我正在寻找的输出是:

[[0,0], 
 [0,1]]

我遇到了“tf.sets.set_intersection”,张量流页面:https://www.tensorflow.org/api_docs/python/tf/sets/set_intersection

但是不能为具有不同形状的张量执行它。我找到的另一个实现是:

Find the intersection of two tensors. Return the sorted, unique values that are in both of the input tensors

但是很难将它复制到2D张量中。

任何帮助将不胜感激,谢谢

tensorflow
1个回答
2
投票

一种方法是对所有组合进行subtract->abs->sum,然后得到匹配为零的索引。可以使用broadcasting实现。

a = tf.constant([[0,0],[0,1],[1,1]])
b = tf.constant([[0, 0],[0, 1],[0,2],[1, 0],[1, 2]])

find_match = tf.reduce_sum(tf.abs(tf.expand_dims(b,0) - tf.expand_dims(a,1)),2)

indices = tf.transpose(tf.where(tf.equal(find_match, tf.zeros_like(find_match))))[0]

out = tf.gather(a, indices)

with tf.Session() as sess:
   print(sess.run(out))
#Output
#[[0 0]
#[0 1]]
© www.soinside.com 2019 - 2024. All rights reserved.