PyTorch 找到两个大张量之间匹配行的索引

问题描述 投票:0回答:0

我有两个二维张量,

A
B
。我想写一个函数
find_indices(A, B)
,它返回一个一维张量,其中包含
A
中的行索引,它也出现在
B
中。此外,该函数应避免使用
for
循环进行并行化。例如:

import torch

A = torch.tensor([[1, 2, 3], [2, 3, 4], [3, 4, 5]])
B = torch.tensor([[1, 2, 3], [2, 3, 6], [2, 5, 6], [3, 4, 5]])

indices1 = find_indices(A, B)  # tensor([0, 2])
indices2 = find_indices(B, A)  # tensor([0, 3])

assert A[indices1].equal(B[indices2])

假设:

  • A
    B
    中的所有行都是唯一的。
  • A
    B
    中的行都已排序。所以相同的两行在
    A
    B
    中以相同的顺序出现。
  • len(A)
    len(B)
    是 ~200k.

我已经从https://stackoverflow.com/a/60494505/17495278:

尝试过这种方法
values, indices = torch.topk(((A.t() == B.unsqueeze(-1)).all(dim=1)).int(), 1, 1)
indices = indices[values!=0]
# indices = tensor([0, 2])

它给出了小规模输入的准确答案。但对于我的用例,它需要 >100 GB 内存并引发 CUDA 内存不足错误。有没有另一种方法可以以合理的内存成本(比如 1 GB 以下)实现这一目标?

python pytorch tensor
© www.soinside.com 2019 - 2024. All rights reserved.