在pytorch张量的一个轴上选择多个索引

问题描述 投票:0回答:1

我的实际问题是在更高的维度上,但我将其发布在更小的维度上以使其易于可视化。

我有一个形状为 (2,3,4) 的张量:

x = torch.randn(2, 3, 4)

tensor([[[-0.9118,  1.4676, -0.4684, -0.6343],
         [ 1.5649,  1.0218, -1.3703,  1.8961],
         [ 0.8652,  0.2491, -0.2556,  0.1311]],

        [[ 0.5289, -1.2723,  2.3865,  0.0222],
         [-1.5528, -0.4638, -0.6954,  0.1661],
         [-1.8151, -0.4634,  1.6490,  0.6957]]])

从这个张量中,我需要选择由沿

axis-1
的索引列表给出的行。

例子,

indices = torch.tensor([0, 2])

预期产出:

tensor([[[-0.9118,  1.4676, -0.4684, -0.6343]],
        [[-1.8151, -0.4634,  1.6490,  0.6957]]])

输出形状:

(2,1,4)

解释: 从x[0]中选择第0行,从x[1]中选择第2行。 (来自指数)

我试过像这样使用

index_select

torch.index_select(x, 1, indices)

但问题是它正在为 x 中的每个项目选择第 0 行和第 2 行。看起来它需要一些修改我现在无法弄清楚。

python multidimensional-array pytorch tensor
1个回答
0
投票

就您而言,这非常简单。平行导航两个维度 的一种简单方法是在第一个轴上使用范围,在第二个轴上使用索引张量:

>>> x[range(len(indices)), indices]
tensor([[-0.9118,  1.4676, -0.4684, -0.6343],
        [-1.8151, -0.4634,  1.6490,  0.6957]])

在更一般的情况下,这需要使用

torch.gather

  • 首先展开索引,使其具有足够的维度:

    index = indices[:,None,None].expand(x.size(0), -1, x.size(-1))
    
  • 然后您可以在

    x
    index
    上应用该功能并挤压
    dim=1

    >>> x.gather(dim=-2, index=index)[:,0]
    tensor([[-0.9118,  1.4676, -0.4684, -0.6343],
            [-1.8151, -0.4634,  1.6490,  0.6957]])
    
© www.soinside.com 2019 - 2024. All rights reserved.