PyTorch batch indexing

Question (votes: 1, answers: 1)

So my network output looks like this:

output = tensor([[[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.0584, -0.2549],
     [ 0.0482, -0.2386],
     [ 0.0410, -0.2234],
     [ 0.0362, -0.2111],
     [ 0.0333, -0.2018],
     [ 0.0318, -0.1951],
     [ 0.0311, -0.1904],
     [ 0.0310, -0.1873],
     [ 0.0312, -0.1851],
     [ 0.0315, -0.1837],
     [ 0.0318, -0.1828],
     [ 0.0322, -0.1822],
     [ 0.0324, -0.1819],
     [ 0.0327, -0.1817],
     [ 0.0328, -0.1815],
     [ 0.0330, -0.1815],
     [ 0.0331, -0.1814],
     [ 0.0332, -0.1814],
     [ 0.0333, -0.1814],
     [ 0.0333, -0.1814],
     [ 0.0334, -0.1814],
     [ 0.0334, -0.1814],
     [ 0.0334, -0.1814]],

    [[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.0584, -0.2549],
     [ 0.0482, -0.2386],
     [ 0.0410, -0.2234],
     [ 0.0362, -0.2111],
     [ 0.0333, -0.2018],
     [ 0.0318, -0.1951],
     [ 0.0311, -0.1904],
     [ 0.0310, -0.1873],
     [ 0.0312, -0.1851],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]],

    [[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.0584, -0.2549],
     [ 0.0482, -0.2386],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]],

    [[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.0584, -0.2549],
     [ 0.0482, -0.2386],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]],

    [[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]],

    [[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]],

    [[ 0.0868, -0.2623],
     [ 0.0716, -0.2668],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]],

    [[ 0.0868, -0.2623],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164],
     [ 0.1003, -0.2164]]])

This is a tensor of shape [8, 24, 2], so my batch size is 8. I now want to pick one data point from each batch element, at the following positions:

index = tensor([24, 10,  3,  3,  1,  1,  1,  0])

The 24th value from the first batch element, the 10th value from the second, and so on.

Now I'm having trouble figuring out the syntax. I tried

torch.gather(output, 0, index)

but it keeps telling me that my dimensions don't match, and

output[:, index]

just gives me the values at all of those indices for every batch element. What is the correct syntax to get these values?

python indexing pytorch batch-processing
1 Answer

0 votes

If you want to select just one element per batch, you need to enumerate the batch indices, which can easily be done with torch.arange:

output[torch.arange(output.size(0)), index]

This essentially pairs the enumerating tensor with your index tensor element by element to access the data, which results in the indices output[0, 24], output[1, 10], and so on.
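
A note on the torch.gather attempt from the question: gather expects an index tensor with the same number of dimensions as the input, which is why the one-dimensional index raises a size-mismatch error. A minimal sketch of how it could be reshaped to work (using random stand-in data and an in-range index, since 24 is out of bounds for a dimension of size 24):

import torch

output = torch.randn(8, 24, 2)                    # stand-in for the network output
index = torch.tensor([23, 10, 3, 3, 1, 1, 1, 0])  # one in-range position per batch element

# Reshape the per-batch index to [8, 1, 2] so it has as many dimensions as output,
# then gather along dim=1 and drop the singleton dimension:
idx = index.view(-1, 1, 1).expand(-1, 1, output.size(2))
gathered = output.gather(1, idx).squeeze(1)       # shape [8, 2]

assert torch.equal(gathered, output[torch.arange(output.size(0)), index])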


0 votes

First, a minor point: for an output of shape [8, 24, 2], the maximum index along the second dimension is 23, so I changed your index to

index = torch.tensor([23, 10,  3,  3,  1,  1,  1,  0])
output = torch.randn((8,24,2)) # Toy data to represent your output

The simplest solution is to use a for loop:

data_pts = torch.zeros((8,2)) # Tensor to store desired values

for i, j in enumerate(index):
    data_pts[i, :] = output[i, j, :]

However, if you want to vectorize the indexing, you just need indices for every dimension. For example:

data_pts_vectorized = output[range(8), index, :] 

Since the first-dimension indices are simply 0 through 7 in order, you can generate them with range.

You can confirm that both approaches give the same result:

assert torch.all(data_pts == data_pts_vectorized)
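
One more equivalent worth knowing (a sketch, assuming PyTorch 1.9 or newer and the same output and index as above): torch.take_along_dim performs the same selection once the index is expanded to have as many dimensions as output, just as torch.gather would require.

# Expand the per-batch index to [8, 1, 2] and select along dim=1:
idx = index.view(-1, 1, 1).expand(-1, 1, output.size(2))
data_pts_tad = torch.take_along_dim(output, idx, dim=1).squeeze(1)

assert torch.all(data_pts == data_pts_tad)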