在 PyTorch 中打乱两个 2D 张量并保持相同的阶相关性

问题描述 投票:0回答:1

是否可以在 PyTorch 中按行对两个 2D 张量进行打乱,但保持两者的顺序相同?我知道您可以使用以下代码按行对 2D 张量进行洗牌:

a=a[torch.randperm(a.size()[0])]

详细说明: 如果我有 2 个张量

a = torch.tensor([[1, 1, 1, 1, 1],
            [2, 2, 2, 2, 2],
            [3, 3, 3, 3, 3]])

b = torch.tensor([[4, 4, 4, 4, 4],
            [5, 5, 5, 5, 5],
            [6, 6, 6, 6, 6]])

并通过一些函数/代码块运行它们以随机洗牌,但保持相关性并产生如下所示的内容

a = torch.tensor([[2, 2, 2, 2, 2],
            [1, 1, 1, 1, 1],
            [3, 3, 3, 3, 3]])

b = torch.tensor([[5, 5, 5, 5, 5],
            [4, 4, 4, 4, 4],
            [6, 6, 6, 6, 6]])

我当前的解决方案是使用 random.shuffle() 函数转换为列表,如下所示。

a_list = a.tolist()
b_list = b.tolist()
temp_list = list(zip(a_list , b_list ))
random.shuffle(temp_list) # Shuffle
a_temp, b_temp = zip(*temp_list)
a_list, b_list = list(a_temp), list(b_temp)
            
# Convert back to tensors
a = torch.tensor(a_list)
b = torch.tensor(b_list)

这需要相当长的时间,想知道是否有更好的方法。

python pytorch 2d tensor
1个回答
7
投票

您可以使用函数

torch.randperm
来获取一组充当随机排列的索引。以下是获取随机排列,然后将其应用于
a
b
张量的小示例:

indices = torch.randperm(a.size()[0])
a=a[indices]
b=b[indices]
© www.soinside.com 2019 - 2024. All rights reserved.