获取张量 B 中存在于张量 A 中的值的索引位置

问题描述 投票:0回答:1

我本质上是在寻找一种完全矢量化的方法来获取张量 B:

[1, 2, 3, 9]
和张量 A:
[1,2,3,3,2,1,4,5,9]
,并且对于张量 B 中的每个值,找到其在张量 A 中的索引位置,以便输出为类似于:
[[0,5], [1,4], [2,3], [-1,8]]
(尽管如果它是一维的,只要我可以检索哪个索引对应于张量 B 中的哪些值的信息),其中每一行对应于张量 B 中的一个值,其中列值是给定值在 A 中出现的索引。

这个方法有效:

def vectorized_find_indices(A, B):
    # Expand dimensions of A for broadcasting
    A_expanded = A[:, None, None]

    # Compare B with expanded A to create a boolean mask
    mask = (B == A_expanded)

    # Get the indices where A matches B
    indices = torch.where(mask, torch.arange(A.size(0), device=A.device)[:, None, None], torch.tensor(-1, device=A.device))

    # Reshape the indices to match the shape of B with an additional dimension for indices
    result = indices.permute(1, 2, 0)

    return result

但是我使用的张量对于广播来说太大了,所以我受到的限制更大。

我还尝试了几种更简单的方法,例如

searchsorted
,我找到了这个解决方案:
(A[..., None] == B).any(-1).nonzero()
,它很接近但还不够,因为返回的索引不再直接附加到值。例如,上面的代码片段将返回:
[0, 1, 2, 3, 4, 5, 8]
,这确实是找到匹配项的正确索引,但信息不再嵌套在第二个维度中,将其与我需要的相应值联系起来,但我我对 pytorch 很不熟悉,所以也许有可能以某种方式获取这些信息并使用这些信息重建它?

python numpy pytorch vectorization
1个回答
0
投票

我不确定是否有办法做到这一点,而不需要沿着 B 广播 A 或在 B 上执行 for 循环以减少内存开销。

一个解决方案可能是

overlap_idxs = (a.unsqueeze(1) == b).nonzero()

output = [[] for i in b]

for (a_idx, b_idx) in overlap_idxs:
    output[b_idx].append(a_idx.item())

output
>[[0, 5], [1, 4], [2, 3], [8]]

或者使用 b: 上的 python 级循环:

output = []

for _b in b:
    idxs = (a==_b).nonzero().squeeze().tolist()
    if type(idxs) != list:
        idxs = [idxs]
        
    output.append(idxs)
    
output
>[[0, 5], [1, 4], [2, 3], [8]]
© www.soinside.com 2019 - 2024. All rights reserved.