在 pytorch 中,我有一个形状为 [b, m, n] 的张量 A 和另一个形状为 [b, k] 的张量 B。我想用 B 索引 A。所以结果张量应该具有形状 [b, k, n]。
我尝试进行一些搜索,但没有成功。 torch.index_select 或 torch.take 只能采用一维索引张量。 torch.gather 要求输入张量和索引张量具有相同的形状。
您尝试在伪代码中执行的操作是:
out[b][k][n] = A[i][B[b][k][n]][n]
torch.gather
,您确实必须具有相同的维度数。您可以通过在 B 上扩展一个额外的单维来实现 (b, k, n)
的形状。这是一个最小的例子:
A = torch.rand(b,m,n)
B = torch.randint(0,m, (b,k))
展开
B
:
>>> B_ = B[:,:,None].expand(-1,-1,A.size(-1))
从
A
收集值:
>>> A.gather(1,B_)