从多个维度的张量中收集不同的元素索引

问题描述 投票:0回答:1

假设我得到以下张量:

arr = torch.randint(0, 9, (100, 50, 3))

我想要实现的是收集,例如,该张量的 2 个元素,让我们从收集第 6 个和第 56 个元素开始:

indices = torch.tensor([5, 55])
partial_arr = arr[indices]

这给了我一个形状数组

torch.Size([2, 50, 3])

现在,我们假设从第一个元素开始,我想收集元素 5 到 10

first_result = partial_arr[0, 5:10]

从第二个元素开始,第10到15个元素:

second_result = partial_arr[1, 10:15]

因为我想要一个张量中的所有内容,所以我可以这样做:

final_result = torch.cat([first_result, second_result])

如何仅对第一个张量进行一次操作即可获得最终结果:

arr = torch.randint(0, 9, (100, 50, 3))

python indexing pytorch
1个回答
0
投票

假设切片元素的数量在各行中保持不变,您可以创建一个排列张量并将其移动每行的起始索引:

>>> idx = torch.tensor([5,10])
>>> idx_ = torch.arange(5,)[None]+idx[:,None]
tensor([[ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14]])

然后展开

idx_
,使其具有与
partial_arr
相同的最后一个维度大小:

>>> idx_ = idx_[...,None].expand(-1,-1,partial_arr.size(-1)) 
# shaped torch.Size([2, 5, 3])

最后,使用

torch.gather
:

收集值
>>> partial_arr.gather(1,idx_).shape
tensor([[[8, 3, 1],
         [2, 4, 6],
         [4, 4, 5],
         [2, 8, 6],
         [3, 7, 0]],

        [[3, 6, 7],
         [5, 7, 4],
         [1, 5, 4],
         [4, 5, 3],
         [7, 1, 2]]])
© www.soinside.com 2019 - 2024. All rights reserved.