当您尝试将 Torch bfloat16 张量转换为 numpy 数组时,它会抛出
TypeError
:
import torch
x = torch.Tensor([0]).to(torch.bfloat16)
x.numpy() # TypeError: Got unsupported ScalarType BFloat16
import numpy as np
np.array(x) # same error
有解决方法可以进行此转换吗?
目前,
numpy
不支持bfloat16**。一种解决方法是在进行转换之前将张量从半精度向上转换为单精度:
x.float().numpy()
Pytorch 维护者也在考虑自动向
force=True
方法添加 Tensor.numpy
选项。