在 Pytorch 中将 `FloatTensor` 转换为 `HalfTensor`?

问题描述 投票:0回答:2

我正在使用 Transformer 的预训练模型,该模型期望输入类型为

torch.HalfTensor
。但是,我有
torch.FloatTensor
类型的输入。如何转换?

pytorch tensor dtype
2个回答
1
投票

补充学习是一团糟的答案:有几种方法可以将张量从浮点转换为半张量

import torch

t_f = torch.FloatTensor(3, 2) 
print(t_f.dtype)  # torch.float32

# all t_hx are of type torch.float16
t_h1 = t_f.half() # works for cpu and  gpu tensors
t_h2 = t_f.type(torch.HalfTensor)  # only for cpu tensors, use torch.cuda.HalfTensor for gpu tensor 
t_h3 = t_f.to(torch.half)  # .to() works also on models

有关 torch 张量数据类型的更多信息,请参阅 docs


1
投票

torch.Tensor.half()
应该做:

torch.rand(3,3).half()
# tensor([[0.1456, 0.0580, 0.9277],
#        [0.3318, 0.1979, 0.6670],
#        [0.5391, 0.2012, 0.2445]], # dtype=torch.float16)
© www.soinside.com 2019 - 2024. All rights reserved.