如何序列化充满复数的火炬张量?

问题描述 投票:0回答:1

我有一些这样的数据:

tensor([[0.9938+0.j, 0.1109+0.j],
        [1.0000+0.j, 0.0000+0.j],
        [0.9450+0.j, 0.3272+0.j],
        [0.9253+0.j, 0.3792+0.j],
        [0.9450+0.j, 0.3272+0.j],
        [0.9028+0.j, 0.4300+0.j],
        [0.8776+0.j, 0.4794+0.j],
        [0.9253+0.j, 0.3792+0.j],
        [0.9253+0.j, 0.3792+0.j],
        [0.9985+0.j, 0.0555+0.j]])

我想将这些数据保存到 JSON 文件中。
我试图通过做

data.numpy().tostring()
将它转换为字符串,但它给了我一个错误说
TypeError: Object of type bytes is not JSON serializable
.

有没有办法把它写成 JSON 然后再读回来?

python json tensor torch
1个回答
0
投票

嗨,pytorch 可以使用 save 保存字典,您可以使用 .load 加载

举个例子:

import torch

save_dict = {
   "test": torch.randn(1, 100),
   "test2": torch.zeros(100),
   "test3": torch.ones(100)
}

torch.save(save_dict, 'test.pt')
torch.load('test.pt')
© www.soinside.com 2019 - 2024. All rights reserved.