我有一个 B x M x N 张量 X,并且我有 B x 1 张量 Y,它对应于我想要保留的维度 = 1 处的张量 X 的索引。这个切片的简写是什么,这样我就可以避免循环?
本质上我想这样做:
Z = torch.zeros(B,N)
for i in range(B):
Z[i] = X[i][Y[i]]
下面的代码与循环中的代码类似。区别在于,我们不是按顺序索引数组
Z
、X
和 Y
,而是使用数组 i
并行索引它们
B, M, N = 13, 7, 19
X = np.random.randint(100, size= [B,M,N])
Y = np.random.randint(M , size= [B,1])
Z = np.random.randint(100, size= [B,N])
i = np.arange(B)
Y = Y.ravel() # reducing array to rank-1, for easy indexing
Z[i] = X[i,Y[i],:]
这段代码可以进一步简化为
>> Z[i] = X[i,Y[i],:]
>> Z[i] = X[i,Y[i]]
>> Z[i] = X[i,Y]
>> Z = X[i,Y]
pytorch 等效代码
B, M, N = 5, 7, 3
X = torch.randint(100, size= [B,M,N])
Y = torch.randint(M , size= [B,1])
Z = torch.randint(100, size= [B,N])
i = torch.arange(B)
Y = Y.ravel()
Z = X[i,Y]
@Hammad 提供的答案很简短,非常适合这项工作。如果您有兴趣使用一些鲜为人知的 Pytorch 内置程序,这里有一个替代解决方案。我们将使用 torch.gather
numpy.take
来实现)。
torch.gather
背后的想法是基于两个形状相同的张量构建一个新的张量,其中包含索引(此处〜
Y
)和值(此处〜
X
)。执行的操作是
Z[i][j][k] = X[i][Y[i][j][k]][k]
。由于
X
的形状是
(B, M, N)
且
Y
形状是
(B, 1)
,我们希望填补
Y
内的空白,使
Y
的形状变成
(B, 1, N)
。这可以通过一些轴操作来实现:
>>> Y.expand(-1, N)[:, None] # expand to dim=1 to N and unsqueeze dim=1
对 torch.gather
的实际调用将是:
>>> X.gather(dim=1, index=Y.expand(-1, N)[:, None])
您可以通过添加 (B, N)
将其重塑为
[:, 0]
。