说,我在张量流中有两个张量如下:
superset = [1,6,7,4,5,6,3,4,8,9,3,2].
subset = [6,3,4,8].
我如何获得指示超集中子集跨度的布尔张量.i.e
我需要一个张量如下:
intersect = [0,0,0,0,0,1,1,1,1,0,0,0].
在张量流中?
这就是我现在的方法,有没有更好的方法?
def getsubset(superset, subset):
result = []
i = 0
while i < len(superset):
print(i)
if i+len(subset) >= len(superset):
result.append(0.0)
i += 1
continue
if all(tf.equal(superset[i:i+len(subset)], subset)):
for k in range(len(subset)):
result.append(1.0)
i += 1
else:
result.append(0.0)
i +=1
return result
这是一个示例,假设
subset
中的每个superset
都是唯一的并且subset
有多个条目:
import tensorflow as tf
def getsubset(superset, subset):
t = tf.reduce_any(tf.equal(superset[:, None], subset), axis=-1)
t = tf.where(tf.logical_and(t[:-1] == t[1:], t[:-1], t[1:]))
indices = tf.concat([t, tf.add(tf.reduce_max(t), 1)[None, None]], axis=0)
return tf.tensor_scatter_nd_update(tf.zeros_like(superset), indices, tf.ones_like(subset))
print(getsubset(tf.constant([1,6,7,4,5,6,3,4,8,9,3,2]), tf.constant([6,3,4,8])))
print(getsubset(tf.constant([1,6,7,4,5,6,3,4,8,9,3,2]), tf.constant([6,7,4])))
print(getsubset(tf.constant([1,6,7,4,5,6,3,4,8,9,3,2]), tf.constant([3,2])))
print(getsubset(tf.constant([1,6,7,4,5,6,3,4,8,9,3,2]), tf.constant([4,8,9])))
tf.Tensor([0 0 0 0 0 1 1 1 1 0 0 0], shape=(12,), dtype=int32)
tf.Tensor([0 1 1 1 0 0 0 0 0 0 0 0], shape=(12,), dtype=int32)
tf.Tensor([0 0 0 0 0 0 0 0 0 0 1 1], shape=(12,), dtype=int32)
tf.Tensor([0 0 0 0 0 0 0 1 1 1 0 0], shape=(12,), dtype=int32)