如何将 pyTorch 张量字典元组转换为张量字典?

问题描述 投票:0回答:1

我有一个字典元组,其中包含 pyTorch 张量:

tuple_of_dicts_of_tensors = (
    {'key_1': torch.tensor([1,1,1]), 'key_2': torch.tensor([4,4,4])},
    {'key_1': torch.tensor([2,2,2]), 'key_2': torch.tensor([5,5,5])},
    {'key_1': torch.tensor([3,3,3]), 'key_2': torch.tensor([6,6,6])}
)

我想将其转换成张量字典:

dict_of_tensors = {
    'key_1': torch.tensor([[1,1,1], [2,2,2], [3,3,3]]),
    'key_2': torch.tensor([[4,4,4], [5,5,5], [6,6,6]])
}

您建议如何这样做?最有效的方法是什么? 张量位于 GPU 设备上,因此需要最少量的 for 循环。

谢谢!

python python-3.x dictionary pytorch tuples
1个回答
0
投票

在此代码中,首先使用字典理解从字典元组中提取每个键的张量来构造 dict_of_lists_of_tensors。然后,对于每个键,使用 torch.stack() 沿新维度堆叠张量列表,这将为您提供所需的张量字典。这种方法最大限度地减少了显式循环的使用,并利用了 PyTorch 的 GPU 加速操作。

import torch

tuple_of_dicts_of_tensors = (
    {'key_1': torch.tensor([1,1,1]), 'key_2': torch.tensor([4,4,4])},
    {'key_1': torch.tensor([2,2,2]), 'key_2': torch.tensor([5,5,5])},
    {'key_1': torch.tensor([3,3,3]), 'key_2': torch.tensor([6,6,6])}
)
dict_of_lists_of_tensors = {key: [d[key] for d in tuple_of_dicts_of_tensors] for key in tuple_of_dicts_of_tensors[0]}
dict_of_tensors = {key: torch.stack(tensor_list) for key, tensor_list in dict_of_lists_of_tensors.items()}

print(dict_of_tensors)

输出:-

{'key_1': tensor([[1, 1, 1],
        [2, 2, 2],
        [3, 3, 3]]), 'key_2': tensor([[4, 4, 4],
        [5, 5, 5],
        [6, 6, 6]])}

© www.soinside.com 2019 - 2024. All rights reserved.