给定张量 A 形状(d0, d1, ..., dn, dn+1)和具有形状的已排序索引 I 的张量(d0, d1, ..., dn) 我想使用 I.
中的排序索引重新排序 A 的索引张量的前n维A和I相等,张量A的第(n+1)维可以任意大小
给定A和I:
>>> A.shape
torch.Size([8, 8, 4])
>>> A
tensor([[[5.6065e-01, 3.1521e-01, 5.7780e-01, 6.7756e-01],
[9.9534e-04, 7.6054e-01, 9.0428e-01, 4.1251e-01],
[8.1525e-01, 3.0477e-01, 3.9605e-01, 2.9155e-01],
[4.9588e-01, 7.4128e-01, 8.8521e-01, 6.1442e-01],
[4.3290e-01, 2.4908e-01, 9.0862e-01, 2.6999e-01],
[9.8264e-01, 4.9388e-01, 4.9769e-01, 2.7884e-02],
[5.7816e-01, 7.5621e-01, 7.0113e-01, 4.4830e-01],
[7.2809e-01, 8.6010e-01, 7.8921e-01, 1.1440e-01]],
...])
>>> I.shape
torch.Size([8, 8])
>>> I
tensor([[2, 7, 4, 6, 1, 3, 0, 5],
...])
A 的倒数第二个维度的元素重新排序后应该是这样的:
>>> A
tensor([[[8.1525e-01, 3.0477e-01, 3.9605e-01, 2.9155e-01],
[7.2809e-01, 8.6010e-01, 7.8921e-01, 1.1440e-01],
[4.3290e-01, 2.4908e-01, 9.0862e-01, 2.6999e-01],
[5.7816e-01, 7.5621e-01, 7.0113e-01, 4.4830e-01],
[9.9534e-04, 7.6054e-01, 9.0428e-01, 4.1251e-01],
[4.9588e-01, 7.4128e-01, 8.8521e-01, 6.1442e-01],
[5.6065e-01, 3.1521e-01, 5.7780e-01, 6.7756e-01],
[9.8264e-01, 4.9388e-01, 4.9769e-01, 2.7884e-02]],
...])
为简单起见,我只包括张量的第一行A和I.
基于已接受的答案,我实现了一个通用版本,可以对任意数量或维度的任意张量进行排序 (d0, d1, ..., dn, dn+1, d n+2, , d..., dn+k) 给定一个排序索引张量 (d0, d1, ..., dn).
这里是代码片段:
import torch
from torch import LongTensor, Tensor
def sort_by_indices(values: Tensor, indices: LongTensor) -> Tensor:
new_shape = tuple(indices.shape) + tuple(
1
for _ in range(values.dim() - indices.dim())
)
repeat_dims = tuple(
1
for _ in range(indices.dim())
) + tuple(values.shape[indices.dim():])
indices = torch.tile(indices.reshape(*new_shape), repeat_dims)
return torch.gather(values, indices.dim() - 1, indices)
torch.gather
但您需要重塑和 tile
指数如下:
(为了更好地展示,我更改了 (8, 8, 4) -> (4, 4, 2) 和 (8, 8) -> (4, 4))
import torch
torch.manual_seed(2023)
A = torch.rand(4, 4, 2)
# First A
# >>> A
# tensor([[[0.4290, 0.7201],
# [0.9481, 0.4797],
# [0.5414, 0.9906],
# [0.4086, 0.2183]],
# [[0.1834, 0.2852],
# [0.7813, 0.1048],
# [0.6550, 0.8375],
# [0.1823, 0.5239]],
# [[0.2432, 0.9644],
# [0.5034, 0.0320],
# [0.8316, 0.3807],
# [0.3539, 0.2114]],
# [[0.9839, 0.6632],
# [0.7001, 0.0155],
# [0.3840, 0.7968],
# [0.4917, 0.4324]]])
B = torch.tensor([
[0, 2, 3, 1],
[1, 3, 0, 2],
[3, 1, 2, 0],
[2, 0, 1, 3]
])
B_changed = torch.tile(B[..., None], (1,1,A.shape[2]))
A_new = torch.gather(a, dim = 1, index = B_changed)
print(A_new)
输出:
tensor([[[0.4290, 0.7201],
[0.5414, 0.9906],
[0.4086, 0.2183],
[0.9481, 0.4797]],
[[0.7813, 0.1048],
[0.1823, 0.5239],
[0.1834, 0.2852],
[0.6550, 0.8375]],
[[0.3539, 0.2114],
[0.5034, 0.0320],
[0.8316, 0.3807],
[0.2432, 0.9644]],
[[0.3840, 0.7968],
[0.9839, 0.6632],
[0.7001, 0.0155],
[0.4917, 0.4324]]])