RuntimeError:_th_normal在LongType上不受支持

问题描述 投票:0回答:1

我正在尝试使用以下方法从正态分布生成一个数字:

from torch.distributions import Normal
noise = Normal(th.tensor([0]), th.tensor(3.20))
noise = noise.sample()

但是我收到此错误:RuntimeError: _th_normal not supported on CPUType for Long

python-3.x pytorch normal-distribution
1个回答
0
投票

您的第一个张量th.tensor([0]) 属于torch.Long类型由于根据传递的值会自动进行类型推断,而功能需要floatFloatTensor

您可以通过像这样显式传递0.0来解决它:

import torch

noise = torch.distributions.Normal(torch.tensor([0.0]), torch.tensor(3.20))
noise = noise.sample()

更好的是,完全删除torch.tensor,在这种情况下,如果可能,Python类型将自动转换为float,因此这也是有效的:

import torch

noise = torch.distributions.Normal(0, 3.20)
noise = noise.sample()

并且请不要将torch别名为th,这不是官方名称,请使用完全限定的名称,因为这只会使所有人感到困惑。

© www.soinside.com 2019 - 2024. All rights reserved.