我有两个张量为1000 * 1的张量。我想检查两个张量中1000个元素中有多少相等。我认为我应该能够像Numpy一样在一行中执行此操作,但是找不到类似的功能。
您可以只使用==
运算符检查是否相等,然后对得出的张量求和:
# Import torch and create dummy tensors
>>> import torch
>>> A = torch.randint(2, (10,))
>>> A
tensor([0, 0, 0, 1, 0, 1, 0, 0, 1, 1])
>>> B = torch.randint(2, (10,))
>>> B
tensor([0, 1, 1, 0, 1, 0, 0, 1, 1, 0])
# Checking for number of equal values
>>> (A == B).sum()
tensor(3)
编辑:
torch.eq
得出相同的结果。因此,如果您出于某些原因更喜欢:
>>> torch.eq(A, B).sum()
tensor(3)
类似
equal_count = len((tensor_1.flatten() == tensor_2.flatten()).nonzero().flatten())
应该工作。