所以我的网络输出是这样的。
output = tensor([[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.0584, -0.2549],
[ 0.0482, -0.2386],
[ 0.0410, -0.2234],
[ 0.0362, -0.2111],
[ 0.0333, -0.2018],
[ 0.0318, -0.1951],
[ 0.0311, -0.1904],
[ 0.0310, -0.1873],
[ 0.0312, -0.1851],
[ 0.0315, -0.1837],
[ 0.0318, -0.1828],
[ 0.0322, -0.1822],
[ 0.0324, -0.1819],
[ 0.0327, -0.1817],
[ 0.0328, -0.1815],
[ 0.0330, -0.1815],
[ 0.0331, -0.1814],
[ 0.0332, -0.1814],
[ 0.0333, -0.1814],
[ 0.0333, -0.1814],
[ 0.0334, -0.1814],
[ 0.0334, -0.1814],
[ 0.0334, -0.1814]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.0584, -0.2549],
[ 0.0482, -0.2386],
[ 0.0410, -0.2234],
[ 0.0362, -0.2111],
[ 0.0333, -0.2018],
[ 0.0318, -0.1951],
[ 0.0311, -0.1904],
[ 0.0310, -0.1873],
[ 0.0312, -0.1851],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.0584, -0.2549],
[ 0.0482, -0.2386],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.0584, -0.2549],
[ 0.0482, -0.2386],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.0716, -0.2668],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]],
[[ 0.0868, -0.2623],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164],
[ 0.1003, -0.2164]]])
这是一个形状为 [8, 24, 2]
现在我的批处理量是8个。我想从每个批次中获取一个数据点,位置如下。
index = tensor([24, 10, 3, 3, 1, 1, 1, 0])
第一批的第24个值,第二批的第10个值,以此类推。
现在,我在理解语法方面遇到了问题。我试过
torch.gather(output, 0, index)
但它一直告诉我,我的尺寸不匹配。
output[ : ,index]
只是让我得到每个批次的所有索引的值。这里正确的语法是什么,才能得到这些值?
如果要在每个批次中只选择一个元素,你需要枚举批次指数,这可以通过以下方法轻松完成。torch.arange
.
output[torch.arange(output.size(0)), index]
这本质上是在枚举张量和你的 index
张量来访问数据,这就导致了索引。output[0, 24]
, output[1, 10]
等。
先说一个小问题,对于一个输出形状[8,24,2],第二段的最大指数可以是23,所以我把你的指数修改成了
index = torch.tensor([23, 10, 3, 3, 1, 1, 1, 0])
output = torch.randn((8,24,2)) # Toy data to represent your output
最简单的解决方案是使用for循环
data_pts = torch.zeros((8,2)) # Tensor to store desired values
for i,j in enumerate(index):
data_pts[i, :] = output[i, j, :]
然而,如果你想将索引矢量化,你只需要所有维度的索引。比如说
data_pts_vectorized = output[range(8), index, :]
由于您的索引向量是有序的,您可以用以下方法生成第一维索引 range
.
你可以确认这两种方法的结果是一样的
assert(torch.all(data_pts == data_pts_vectorized))