假设我得到以下张量:
arr = torch.randint(0, 9, (100, 50, 3))
我想要实现的是收集,例如,该张量的 2 个元素,让我们从收集第 6 个和第 56 个元素开始:
indices = torch.tensor([5, 55])
partial_arr = arr[indices]
这给了我一个形状数组
torch.Size([2, 50, 3])
现在,我们假设从第一个元素开始,我想收集元素 5 到 10
first_result = partial_arr[0, 5:10]
从第二个元素开始,第10到15个元素:
second_result = partial_arr[1, 10:15]
因为我想要一个张量中的所有内容,所以我可以这样做:
final_result = torch.cat([first_result, second_result])
如何仅对第一个张量进行一次操作即可获得最终结果:
arr = torch.randint(0, 9, (100, 50, 3))
?
假设切片元素的数量在各行中保持不变,您可以创建一个排列张量并将其移动每行的起始索引:
>>> idx = torch.tensor([5,10])
>>> idx_ = torch.arange(5,)[None]+idx[:,None]
tensor([[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14]])
然后展开
idx_
,使其具有与 partial_arr
相同的最后一个维度大小:
>>> idx_ = idx_[...,None].expand(-1,-1,partial_arr.size(-1))
# shaped torch.Size([2, 5, 3])
torch.gather
: 收集值
>>> partial_arr.gather(1,idx_).shape
tensor([[[8, 3, 1],
[2, 4, 6],
[4, 4, 5],
[2, 8, 6],
[3, 7, 0]],
[[3, 6, 7],
[5, 7, 4],
[1, 5, 4],
[4, 5, 3],
[7, 1, 2]]])