我本质上是在寻找一种完全矢量化的方法来获取张量 B:
[1, 2, 3, 9]
和张量 A: [1,2,3,3,2,1,4,5,9]
,并且对于张量 B 中的每个值,找到其在张量 A 中的索引位置,以便输出为类似于:[[0,5], [1,4], [2,3], [-1,8]]
(尽管如果它是一维的,只要我可以检索哪个索引对应于张量 B 中的哪些值的信息),其中每一行对应于张量 B 中的一个值,其中列值是给定值在 A 中出现的索引。
这个方法有效:
def vectorized_find_indices(A, B):
# Expand dimensions of A for broadcasting
A_expanded = A[:, None, None]
# Compare B with expanded A to create a boolean mask
mask = (B == A_expanded)
# Get the indices where A matches B
indices = torch.where(mask, torch.arange(A.size(0), device=A.device)[:, None, None], torch.tensor(-1, device=A.device))
# Reshape the indices to match the shape of B with an additional dimension for indices
result = indices.permute(1, 2, 0)
return result
但是我使用的张量对于广播来说太大了,所以我受到的限制更大。
我还尝试了几种更简单的方法,例如
searchsorted
,我找到了这个解决方案:(A[..., None] == B).any(-1).nonzero()
,它很接近但还不够,因为返回的索引不再直接附加到值。例如,上面的代码片段将返回:[0, 1, 2, 3, 4, 5, 8]
,这确实是找到匹配项的正确索引,但信息不再嵌套在第二个维度中,将其与我需要的相应值联系起来,但我我对 pytorch 很不熟悉,所以也许有可能以某种方式获取这些信息并使用这些信息重建它?
我不确定是否有办法做到这一点,而不需要沿着 B 广播 A 或在 B 上执行 for 循环以减少内存开销。
一个解决方案可能是
overlap_idxs = (a.unsqueeze(1) == b).nonzero()
output = [[] for i in b]
for (a_idx, b_idx) in overlap_idxs:
output[b_idx].append(a_idx.item())
output
>[[0, 5], [1, 4], [2, 3], [8]]
或者使用 b: 上的 python 级循环:
output = []
for _b in b:
idxs = (a==_b).nonzero().squeeze().tolist()
if type(idxs) != list:
idxs = [idxs]
output.append(idxs)
output
>[[0, 5], [1, 4], [2, 3], [8]]