PyTorch load_state_dict() does not load exact values

Problem description

For simplicity, I want to set all parameters of a torch model to the constant 72114982 with this code:

import torch
import torch.nn as nn

model = Net()
params = model.state_dict()

# replace every entry with a tensor filled with the constant
for k, v in params.items():
    params[k] = torch.full(v.shape, 72114982, dtype=torch.long)

model.load_state_dict(params)
print(model.state_dict().values())

The print statement then shows that all values were actually set to 72114984, which is 2 off from the value I originally wanted.

For simplicity, Net is defined as follows:

class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, 2)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(2, 2)
Tags: pytorch, precision
1 Answer

This is a data type issue.

The model parameters are float tensors, so load_state_dict casts the values you pass in to each parameter's dtype (here float32).

72114982 is large enough that it has no exact float32 representation: float32 carries only 24 significant bits, so between 2^26 and 2^27 consecutive representable values are 8 apart, and 72114982 rounds to the nearest one, 72114984.
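
To illustrate that spacing, here is a minimal sketch (torch.nextafter is my addition, not part of the original answer):

import torch

f = torch.tensor(72114982, dtype=torch.long).float()
print(f.item())  # 72114984.0, the nearest float32 value

# stepping to the adjacent representable value towards zero lands 8 below
print(torch.nextafter(f, torch.tensor(0.0)).item())  # 72114976.0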

You can verify the rounding and comparison behavior yourself:

import torch

x = torch.tensor(72114982, dtype=torch.long)
y = x.float()  # y will actually be 72114984.0

# this returns True because x is cast to float before comparing
x == y
> tensor(True)

# for the same reason, the difference evaluates to 0
y - x
> tensor(0.)

# this returns False because the tensors hold different values and no cast to float happens
x == y.long()
> tensor(False)

# as longs, the difference correctly evaluates to 2
y.long() - x
> tensor(2)
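
If you need the parameters to hold the exact integer, one workaround (a sketch of mine, not part of the original answer; it assumes the Net class from the question) is to use a dtype that can represent it exactly, such as float64:

import torch
import torch.nn as nn

model = Net().double()  # float64 parameters: 53 significant bits, exact for integers up to 2**53
params = model.state_dict()

for k, v in params.items():
    params[k] = torch.full(v.shape, 72114982, dtype=torch.double)

model.load_state_dict(params)

# every parameter now holds exactly 72114982.0
assert all((v == 72114982).all() for v in model.state_dict().values())

Alternatively, pick a constant that float32 can represent exactly (any integer up to 2**24, or a multiple of the local spacing).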