为了简单起见,我想使用此代码将火炬模型的所有参数设置为常量
72114982
model = Net()
params = model.state_dict()
for k, v in params.items():
params[k] = torch.full(v.shape, 72114982, dtype=torch.long)
model.load_state_dict(params)
print(model.state_dict().values())
然后 print 语句显示所有值实际上都设置为
72114984
,与我最初想要的值相差 2。
为了简单起见,定义
Net
如下
class Net(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(2, 2, 2)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(2, 2)
这是数据类型的问题。
模型参数被转换为浮点张量。
72114984
足够大,其浮点表示舍入为 72114984
。
您可以通过以下方式验证这一点:
x = torch.tensor(72114982, dtype=torch.long)
y = x.float() # y will actually be `72114984.0`
# this returns `True` because x is cast to a float before evaluating
x == y
> tensor(True)
# for the same reason, this returns 0.
y - x
> tensor(0.)
# this returns `False` because the tensors have different values and we don't cast to float
x == y.long()
> tensor(False)
# as longs, the difference correctly evaluates to 2
y.long() - x
> tensor(2)