我有一个字典元组,其中包含 pyTorch 张量:
tuple_of_dicts_of_tensors = (
{'key_1': torch.tensor([1,1,1]), 'key_2': torch.tensor([4,4,4])},
{'key_1': torch.tensor([2,2,2]), 'key_2': torch.tensor([5,5,5])},
{'key_1': torch.tensor([3,3,3]), 'key_2': torch.tensor([6,6,6])}
)
我想将其转换成张量字典:
dict_of_tensors = {
'key_1': torch.tensor([[1,1,1], [2,2,2], [3,3,3]]),
'key_2': torch.tensor([[4,4,4], [5,5,5], [6,6,6]])
}
您建议如何这样做?最有效的方法是什么? 张量位于 GPU 设备上,因此需要最少量的 for 循环。
谢谢!
在此代码中,首先使用字典理解从字典元组中提取每个键的张量来构造 dict_of_lists_of_tensors。然后,对于每个键,使用 torch.stack() 沿新维度堆叠张量列表,这将为您提供所需的张量字典。这种方法最大限度地减少了显式循环的使用,并利用了 PyTorch 的 GPU 加速操作。
import torch
tuple_of_dicts_of_tensors = (
{'key_1': torch.tensor([1,1,1]), 'key_2': torch.tensor([4,4,4])},
{'key_1': torch.tensor([2,2,2]), 'key_2': torch.tensor([5,5,5])},
{'key_1': torch.tensor([3,3,3]), 'key_2': torch.tensor([6,6,6])}
)
dict_of_lists_of_tensors = {key: [d[key] for d in tuple_of_dicts_of_tensors] for key in tuple_of_dicts_of_tensors[0]}
dict_of_tensors = {key: torch.stack(tensor_list) for key, tensor_list in dict_of_lists_of_tensors.items()}
print(dict_of_tensors)
输出:-
{'key_1': tensor([[1, 1, 1],
[2, 2, 2],
[3, 3, 3]]), 'key_2': tensor([[4, 4, 4],
[5, 5, 5],
[6, 6, 6]])}