将 Pytorch bfloat16 张量转换为 numpy 会抛出 TypeError

问题描述 投票:0回答:1

当您尝试将 Torch bfloat16 张量转换为 numpy 数组时,它会抛出

TypeError
:

import torch

x = torch.Tensor([0]).to(torch.bfloat16)
x.numpy()  # TypeError: Got unsupported ScalarType BFloat16

import numpy as np
np.array(x)  # same error

有解决方法可以进行此转换吗?

python numpy pytorch floating-point tensor
1个回答
0
投票

目前,

numpy
不支持bfloat16**。一种解决方法是在进行转换之前将张量从半精度向上转换为单精度:

x.float().numpy()

Pytorch 维护者也在考虑自动向

force=True
方法添加
Tensor.numpy
选项。

**虽然可能会改变感谢@jakevdp工作

© www.soinside.com 2019 - 2024. All rights reserved.