pytorch 用另一个多维张量索引多维张量

问题描述 投票:0回答:1

在 pytorch 中,我有一个形状为 [b, m, n] 的张量 A 和另一个形状为 [b, k] 的张量 B。我想用 B 索引 A。所以结果张量应该具有形状 [b, k, n]。

我尝试进行一些搜索,但没有成功。 torch.index_select 或 torch.take 只能采用一维索引张量。 torch.gather 要求输入张量和索引张量具有相同的形状。

python indexing pytorch
1个回答
0
投票

您尝试在伪代码中执行的操作是:

out[b][k][n] = A[i][B[b][k][n]][n]

要使用

torch.gather
,您确实必须具有相同的维度数。您可以通过在 B 上扩展一个额外的单维来实现
(b, k, n)
的形状。这是一个最小的例子:

A = torch.rand(b,m,n)
B = torch.randint(0,m, (b,k))

展开

B

>>> B_ = B[:,:,None].expand(-1,-1,A.size(-1))

A
收集值:

>>> A.gather(1,B_)
© www.soinside.com 2019 - 2024. All rights reserved.