计算一个Tensor中有多少个元素存在于另一个Tensor中

问题描述 投票:0回答:1

我有两个一维张量:

A = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
B = torch.tensor([2, 5, 6, 8, 12, 15, 16])

张量很大,长度不同,并且值的顺序较近,也没有排序。

我想得到 B 中 (i) 存在于 A 中 (ii) 不存在于 A 中的元素的数量。因此,输出将是:

Exists: 4
Do not exist: 3

我已经尝试过:

exists = torch.eq(A,B).sum().item()
not_exist = torch.numel(B) - exists

但这给出了错误:

RuntimeError: The size of tensor a (10) must match the size of tensor b (7) at non-singleton dimension 0

以下方法有效,但它涉及首先创建

boolean
张量,然后对
true
元素求和。对于非常大的张量是否有效?

exists = np.isin(A,B).sum()
not_exist = torch.numel(B) - exists

有没有更好或更有效的方法?

python pytorch tensor
1个回答
0
投票

尝试以下操作: 进口手电筒

A = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
B = torch.tensor([2, 5, 6, 8, 12, 15, 16])

# Convert tensors to sets
setA = set(A.numpy())
setB = set(B.numpy())

# Find intersection and difference
intersection = setA & setB
difference = setB - setA

# Calculate the counts
exists = len(intersection)
not_exist = len(difference)

print(f"Exists: {exists}")
print(f"Do not exist: {not_exist}")
© www.soinside.com 2019 - 2024. All rights reserved.