我的实际问题是在更高的维度上,但我将其发布在更小的维度上以使其易于可视化。
我有一个形状为 (2,3,4) 的张量:
x = torch.randn(2, 3, 4)
tensor([[[-0.9118, 1.4676, -0.4684, -0.6343],
[ 1.5649, 1.0218, -1.3703, 1.8961],
[ 0.8652, 0.2491, -0.2556, 0.1311]],
[[ 0.5289, -1.2723, 2.3865, 0.0222],
[-1.5528, -0.4638, -0.6954, 0.1661],
[-1.8151, -0.4634, 1.6490, 0.6957]]])
从这个张量中,我需要选择由沿
axis-1
的索引列表给出的行。
例子,
indices = torch.tensor([0, 2])
预期产出:
tensor([[[-0.9118, 1.4676, -0.4684, -0.6343]],
[[-1.8151, -0.4634, 1.6490, 0.6957]]])
输出形状:
(2,1,4)
解释: 从x[0]中选择第0行,从x[1]中选择第2行。 (来自指数)
我试过像这样使用
index_select
:
torch.index_select(x, 1, indices)
但问题是它正在为 x 中的每个项目选择第 0 行和第 2 行。看起来它需要一些修改我现在无法弄清楚。
就您而言,这非常简单。平行导航两个维度 的一种简单方法是在第一个轴上使用范围,在第二个轴上使用索引张量:
>>> x[range(len(indices)), indices]
tensor([[-0.9118, 1.4676, -0.4684, -0.6343],
[-1.8151, -0.4634, 1.6490, 0.6957]])
torch.gather
:
首先展开索引,使其具有足够的维度:
index = indices[:,None,None].expand(x.size(0), -1, x.size(-1))
然后您可以在
x
和index
上应用该功能并挤压dim=1
:
>>> x.gather(dim=-2, index=index)[:,0]
tensor([[-0.9118, 1.4676, -0.4684, -0.6343],
[-1.8151, -0.4634, 1.6490, 0.6957]])