我需要按第一列的键值对一批二维矩阵的行进行排序:
原始批量矩阵(3d 张量):
torch.tensor([[[2, 0],
[0, 1],
[1, 2]],
[[1, 2],
[0, 0],
[2, 1]]])
所需张量:
torch.tensor([[[0, 1],
[1, 2],
[2, 0]],
[[0, 0],
[1, 2],
[2, 1]]])
已经知道如何处理其中一个批次,以及另一个答案通过for循环解决问题,这不是并行的。那么如何并行处理整个批次呢?
结果可能有点令人困惑,但很有意义:
(my_tensor[:,torch.argsort(my_tensor[:,:,0], dim=1)])\
[torch.arange(len(my_tensor)),torch.arange(len(my_tensor))]
我在第一行中提取了排序张量
torch.argsort
并将其应用于 my_tensor
,从而产生了 (2, 2, 3, 2)
形状张量。由于您希望每个元素仅根据其第一列进行排序,因此您只对前两个维度的对角线感兴趣,并且可以通过切片(第二行代码)来提取它。