批量张量切片,切片 B x N x M 且 B x 1

问题描述 投票:0回答:2

我有一个 B x M x N 张量 X,并且我有 B x 1 张量 Y,它对应于我想要保留的维度 = 1 处的张量 X 的索引。这个切片的简写是什么,这样我就可以避免循环?

本质上我想这样做:

Z = torch.zeros(B,N)

for i in range(B):
    Z[i] = X[i][Y[i]]
numpy pytorch slice tensor numpy-slicing
2个回答
3
投票

下面的代码与循环中的代码类似。区别在于,我们不是按顺序索引数组

Z
X
Y
,而是使用数组
i

并行索引它们
B, M, N = 13, 7, 19

X = np.random.randint(100, size= [B,M,N])
Y = np.random.randint(M  , size= [B,1])
Z = np.random.randint(100, size= [B,N])

i = np.arange(B)
Y = Y.ravel()    # reducing array to rank-1, for easy indexing

Z[i] = X[i,Y[i],:]

这段代码可以进一步简化为

>> Z[i] = X[i,Y[i],:]
>> Z[i] = X[i,Y[i]]
>> Z[i] = X[i,Y]
>> Z    = X[i,Y]

pytorch 等效代码

B, M, N = 5, 7, 3

X = torch.randint(100, size= [B,M,N])
Y = torch.randint(M  , size= [B,1])
Z = torch.randint(100, size= [B,N])

i = torch.arange(B)
Y = Y.ravel()

Z = X[i,Y]

1
投票

@Hammad 提供的答案很简短,非常适合这项工作。如果您有兴趣使用一些鲜为人知的 Pytorch 内置程序,这里有一个替代解决方案。我们将使用 torch.gather

(同样,您可以使用 numpy.take
 来实现)。

torch.gather

背后的想法是基于两个形状相同的张量构建一个新的张量,其中包含索引(此处〜
Y
)和值(此处〜
X
)。

执行的操作是

Z[i][j][k] = X[i][Y[i][j][k]][k]

由于

X

 的形状是 
(B, M, N)
Y
 形状是 
(B, 1)
,我们希望填补 
Y
 内的空白,使 
Y
 的形状变成 
(B, 1, N)

这可以通过一些轴操作来实现:

>>> Y.expand(-1, N)[:, None] # expand to dim=1 to N and unsqueeze dim=1

torch.gather

 的实际调用将是:

>>> X.gather(dim=1, index=Y.expand(-1, N)[:, None])
您可以通过添加 

(B, N)

 将其重塑为 
[:, 0]


此功能在棘手的场景中非常有效...

© www.soinside.com 2019 - 2024. All rights reserved.