如何按特定键值“批量”排序张量?

问题描述 投票:0回答:1

我需要按第一列的键值对一批二维矩阵的行进行排序:

原始批量矩阵(3d 张量):

torch.tensor([[[2, 0], 
               [0, 1],
               [1, 2]],

              [[1, 2], 
               [0, 0], 
               [2, 1]]])

所需张量:

torch.tensor([[[0, 1],
               [1, 2],
               [2, 0]],

              [[0, 0],
               [1, 2],  
               [2, 1]]])

已经知道如何处理其中一个批次,以及另一个答案通过for循环解决问题,这不是并行的。那么如何并行处理整个批次呢?

numpy pytorch tensor
1个回答
0
投票

结果可能有点令人困惑,但很有意义:

(my_tensor[:,torch.argsort(my_tensor[:,:,0], dim=1)])\
[torch.arange(len(my_tensor)),torch.arange(len(my_tensor))]

我在第一行中提取了排序张量

torch.argsort
并将其应用于
my_tensor
,从而产生了
(2, 2, 3, 2)
形状张量。由于您希望每个元素仅根据其第一列进行排序,因此您只对前两个维度的对角线感兴趣,并且可以通过切片(第二行代码)来提取它。

© www.soinside.com 2019 - 2024. All rights reserved.