我有两个二维张量,
A
和 B
。我想写一个函数find_indices(A, B)
,它返回一个一维张量,其中包含A
中的行索引,它也出现在B
中。此外,该函数应避免使用 for
循环进行并行化。例如:
import torch
A = torch.tensor([[1, 2, 3], [2, 3, 4], [3, 4, 5]])
B = torch.tensor([[1, 2, 3], [2, 3, 6], [2, 5, 6], [3, 4, 5]])
indices1 = find_indices(A, B) # tensor([0, 2])
indices2 = find_indices(B, A) # tensor([0, 3])
assert A[indices1].equal(B[indices2])
假设:
A
和B
中的所有行都是唯一的。A
和 B
中的行都已排序。所以相同的两行在A
和B
中以相同的顺序出现。len(A)
和 len(B)
是 ~200k.我已经从https://stackoverflow.com/a/60494505/17495278:
尝试过这种方法values, indices = torch.topk(((A.t() == B.unsqueeze(-1)).all(dim=1)).int(), 1, 1)
indices = indices[values!=0]
# indices = tensor([0, 2])
它给出了小规模输入的准确答案。但对于我的用例,它需要 >100 GB 内存并引发 CUDA 内存不足错误。有没有另一种方法可以以合理的内存成本(比如 1 GB 以下)实现这一目标?