Simple linear regression in PyTorch - why does the loss increase with each epoch?

Problem description (votes: 0, answers: 2)

I am trying to build a simple linear regression model with PyTorch to predict the perceived ("feels-like") temperature atemp from the actual temperature temp.

I don't understand why this code causes the loss to increase with each epoch instead of decreasing. All of the predicted values are far from the truth.

Sample data used:

data_x = array([11.9, 12. , 13.4, 14.8, 15.8, 16.6, 16.7, 16.9, 16.9, 16.9, 16.5,
       15.7, 15.3, 15. , 15. , 14.9, 14.6, 14.2, 14.2, 14. , 13.5, 12.9,
       12.5, 12.4, 12.8, 14.3, 15.6, 16.5, 17. , 17.5, 17.7, 17.7, 17.8,
       17.5, 16.9, 15.6, 14. , 12.2, 11. , 10.6, 10.6, 10.7, 10.9, 10.6,
       10.3,  9.4,  8.7,  7.8,  8.1, 11. , 13.4, 15.2, 16.5, 17.4, 18.1,
       18.5, 18.7, 18.6, 17.7, 16. , 14.6, 13.8, 13. , 12.5, 12. , 11.8,
       11.5, 11.3, 10.9, 10.6, 10.2,  9.9, 10.5, 13.1, 15.3, 17.2, 18.9,
       20.3, 21.2, 21.8, 21.9, 21.5, 20.2, 18.3, 16.8, 15.8, 14.9, 14.2,
       13.6, 13.2, 12.9, 12.7, 12.6, 12.6, 12.6, 12.8, 13.4, 15.5, 17.6,
       19.3])
data_y = array([ 8.9,  9.3, 10.7, 12.1, 13.1, 13.8, 14. , 14.1, 14.3, 14.5, 14.3,
       13.7, 13.2, 12.7, 12.7, 12.5, 11.9, 11.7, 11.7, 11.5, 11.1, 10.6,
       10.3, 10.2, 10.9, 12.5, 12.8, 13.8, 14.6, 14.9, 14.9, 15.1, 15.5,
       15.6, 15.8, 14.7, 13.1, 11.2,  9.6,  9.1,  9.4,  9.7,  9.9,  9.6,
        9.2,  8. ,  7.1,  6.1,  6.5, 10.2, 12.7, 14.3, 15.5, 16.6, 17.4,
       17.7, 17.8, 17.6, 17.2, 15.3, 13.4, 12.4, 11.5, 10.8, 10.1, 10. ,
        9.8,  9.6,  9.3,  9. ,  8.5,  8.1,  8.8, 12. , 14.4, 16.6, 18.5,
       20.1, 21. , 21.3, 21.2, 21.2, 20.1, 17.9, 16.1, 14.6, 13.8, 13.1,
       12.3, 11.8, 11.6, 11.4, 11.3, 11.3, 11.3, 11.4, 12. , 14.6, 16.8,
       18.8])

Plotted data:

[plot of temp vs. atemp omitted]

Code

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim

# import data from CSV into a pandas DataFrame
bg = pd.read_csv('data.csv')
X_pandas = bg['temp']
y_pandas = bg['atemp']

# convert to tensors
data_x = X_pandas.head(100).values
data_y = y_pandas.head(100).values
X = torch.tensor(data_x, dtype=torch.float32).reshape(-1, 1)
y = torch.tensor(data_y, dtype=torch.float32).reshape(-1, 1)

# create the model
model = nn.Linear(1, 1)
loss_fn = nn.MSELoss()  # mean square error
optimizer = optim.SGD(model.parameters(), lr=0.01)

# train the model
n_epochs = 40   # number of epochs to run
for epoch in range(n_epochs):
    # forward pass
    y_pred = model(X)
    # compute loss
    loss = loss_fn(y_pred, y)
    # backward pass
    loss.backward()
    # update parameters
    optimizer.step()
    # zero gradients
    optimizer.zero_grad()
    # print loss
    print(f'epoch: {epoch + 1}, loss = {loss.item():.4f}')

# display the predicted values
predicted = model(X).detach().numpy()
display(predicted)

Output

epoch: 1, loss = 16.5762
epoch: 2, loss = 191.0379
epoch: 3, loss = 2291.5081
epoch: 4, loss = 27580.5195
epoch: 5, loss = 332052.6875
epoch: 6, loss = 3997804.2500
epoch: 7, loss = 48132328.0000
epoch: 8, loss = 579498624.0000
epoch: 9, loss = 6976988160.0000
epoch: 10, loss = 84000866304.0000
epoch: 11, loss = 1011344670720.0000
epoch: 12, loss = 12176279470080.0000
epoch: 13, loss = 146598776537088.0000
epoch: 14, loss = 1765004462260224.0000
epoch: 15, loss = 21250117348622336.0000
epoch: 16, loss = 255844948350337024.0000
epoch: 17, loss = 3080297218377252864.0000
epoch: 18, loss = 37085819119396192256.0000
epoch: 19, loss = 446502312996857970688.0000
epoch: 20, loss = 5375748153858603352064.0000
epoch: 21, loss = 64722396677244886974464.0000
epoch: 22, loss = 779237667397586303057920.0000
epoch: 23, loss = 9381773651754967424303104.0000
epoch: 24, loss = 112953739724808869434621952.0000
epoch: 25, loss = 1359928800566679308764971008.0000
epoch: 26, loss = 16373128158657455337028714496.0000
epoch: 27, loss = 197127444146361433227589058560.0000
epoch: 28, loss = 2373354706586702693378941779968.0000
epoch: 29, loss = 28574463232459721913615454830592.0000
epoch: 30, loss = 344027831021918449557295178186752.0000
epoch: 31, loss = 4141990153063893156517557464727552.0000
epoch: 32, loss = 49868270370463502095675094080684032.0000
epoch: 33, loss = 600398977963427833849804206813216768.0000
epoch: 34, loss = inf
epoch: 35, loss = inf
epoch: 36, loss = inf
epoch: 37, loss = inf
epoch: 38, loss = inf
epoch: 39, loss = inf
epoch: 40, loss = inf

Predicted values:

array([[1.60481241e+21],
       [1.61822441e+21],
       [1.80599158e+21],
       [1.99375890e+21],
       [2.12787834e+21],
       [2.23517393e+21],
       [2.24858593e+21],
       [2.27540965e+21],
       [2.27540965e+21],
       [2.27540965e+21],
       ...

What could be the cause of this strange result?

python machine-learning deep-learning pytorch linear-regression
2 Answers
1 vote

My problem seems to have been that a learning rate of 0.01 was too high for this problem and this amount of data.

Changing this bit fixed the problem:

# before
optimizer = optim.SGD(model.parameters(), lr=0.01)

# after
optimizer = optim.SGD(model.parameters(), lr=0.005)
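
Some context on why such a small change matters: full-batch gradient descent on an MSE loss diverges as soon as the learning rate exceeds 2/λ_max, where λ_max is the largest eigenvalue of the loss Hessian, and for linear regression that eigenvalue scales with mean(x²). A minimal sketch of this check, assuming data_x is the NumPy array shown in the question:

import numpy as np

# Hessian of L(w, b) = mean((w*x + b - y)^2) is (2/n) * A^T A
# for the design matrix A = [x, 1]; gradient descent is stable
# only while lr < 2 / lambda_max
A = np.column_stack([data_x, np.ones_like(data_x)])
H = (2.0 / len(data_x)) * A.T @ A
lam_max = np.linalg.eigvalsh(H).max()
print(f'lambda_max = {lam_max:.1f}, stable lr below ~{2.0 / lam_max:.4f}')

With temperatures around 15, mean(x²) is in the hundreds, so the stable range works out to a few thousandths. That is consistent with the question's output, where the loss grows by roughly ×12 per epoch (the parameter error by about ×3.5 ≈ |1 - lr·λ_max|), and it explains why 0.005 lands close to the edge of stability.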


0 votes

Scaling helps, without changing lr:
# X_pandas = bg['temp']
# y_pandas = bg['atemp']

data_x = data_x/data_x.max()  # scale inputs into (0, 1] by dividing by the max
data_y = data_y               # targets left as-is

# convert to tensors
# data_x = X_pandas.head(100).values
# data_y = y_pandas.head(100).values
X = torch.tensor(data_x, dtype=torch.float32).reshape(-1, 1)
y = torch.tensor(data_y, dtype=torch.float32).reshape(-1, 1)

# create the model
model = nn.Linear(1, 1)
loss_fn = nn.MSELoss()  # mean square error
optimizer = optim.SGD(model.parameters(), lr=0.01)

# train the model
n_epochs = 40   # number of epochs to run
for epoch in range(n_epochs):
    # forward pass
    y_pred = model(X)
    # compute loss
    loss = loss_fn(y_pred, y)
    # backward pass
    loss.backward()
    # update parameters
    optimizer.step()
    # zero gradients
    optimizer.zero_grad()
    # print loss
    print(f'epoch: {epoch + 1}, loss = {loss.item():.4f}')

# display the predicted values
predicted = model(X).detach().numpy()
display(predicted)

epoch: 1, loss = 210.4702
epoch: 2, loss = 198.8098
epoch: 3, loss = 187.8156
epoch: 4, loss = 177.4496
epoch: 5, loss = 167.6758
epoch: 6, loss = 158.4604
epoch: 7, loss = 149.7714
epoch: 8, loss = 141.5789
epoch: 9, loss = 133.8544
epoch: 10, loss = 126.5711
epoch: 11, loss = 119.7039
epoch: 12, loss = 113.2290
epoch: 13, loss = 107.1239
epoch: 14, loss = 101.3676
epoch: 15, loss = 95.9400
epoch: 16, loss = 90.8225
epoch: 17, loss = 85.9972
epoch: 18, loss = 81.4475
epoch: 19, loss = 77.1577
epoch: 20, loss = 73.1128
epoch: 21, loss = 69.2989
epoch: 22, loss = 65.7028
epoch: 23, loss = 62.3120
epoch: 24, loss = 59.1148
epoch: 25, loss = 56.1002
epoch: 26, loss = 53.2576
epoch: 27, loss = 50.5773
epoch: 28, loss = 48.0500
epoch: 29, loss = 45.6670
epoch: 30, loss = 43.4200
epoch: 31, loss = 41.3012
epoch: 32, loss = 39.3033
epoch: 33, loss = 37.4193
epoch: 34, loss = 35.6429
epoch: 35, loss = 33.9678
epoch: 36, loss = 32.3883
epoch: 37, loss = 30.8988
epoch: 38, loss = 29.4943
epoch: 39, loss = 28.1699
epoch: 40, loss = 26.9209
array([[ 8.267589 ],
       [ 8.287247 ],
       [ 8.562464 ],
       [ 8.837682 ],
       [ 9.0342655],
       [ 9.191533 ],
       [ 9.211191 ],
       [ 9.250508 ],
       [ 9.250508 ],
       [ 9.250508 ],
       [ 9.171875 ],
       [ 9.014607 ],
       [ 8.935974 ],
       [ 8.876999 ],
       [ 8.876999 ],
       [ 8.85734  ],
       [ 8.798365 ],
       [ 8.719731 ],
       [ 8.719731 ],
       [ 8.680414 ],
       [ 8.582123 ],
       [ 8.464172 ],
       [ 8.385539 ],
       [ 8.36588  ],
       [ 8.444513 ],
       [ 8.739389 ],
       [ 8.994949 ],
       [ 9.171875 ],
       [ 9.270166 ],
       [ 9.368459 ],
       [ 9.407776 ],
       [ 9.407776 ],
       [ 9.427434 ],
       [ 9.368459 ],
       [ 9.250508 ],
       [ 8.994949 ],
       [ 8.680414 ],
       [ 8.326564 ],
       [ 8.090663 ],
       [ 8.012029 ],
       [ 8.012029 ],
       [ 8.031688 ],
       [ 8.071004 ],
       [ 8.012029 ],
       [ 7.9530535],
       [ 7.7761283],
       [ 7.6385193],
       [ 7.4615936],
       [ 7.520569 ],
       [ 8.090663 ],
       [ 8.562464 ],
       [ 8.916315 ],
       [ 9.171875 ],
       [ 9.348801 ],
       [ 9.486409 ],
       [ 9.5650425],
       [ 9.60436  ],
       [ 9.584702 ],
       [ 9.407776 ],
       [ 9.073583 ],
       [ 8.798365 ],
       [ 8.641098 ],
       [ 8.48383  ],
       [ 8.385539 ],
       [ 8.287247 ],
       [ 8.24793  ],
       [ 8.188955 ],
       [ 8.149638 ],
       [ 8.071004 ],
       [ 8.012029 ],
       [ 7.9333954],
       [ 7.87442  ],
       [ 7.9923706],
       [ 8.503489 ],
       [ 8.935974 ],
       [ 9.309484 ],
       [ 9.643677 ],
       [ 9.918894 ],
       [10.095819 ],
       [10.21377  ],
       [10.233429 ],
       [10.154795 ],
       [ 9.899236 ],
       [ 9.525726 ],
       [ 9.23085  ],
       [ 9.0342655],
       [ 8.85734  ],
       [ 8.719731 ],
       [ 8.601781 ],
       [ 8.523148 ],
       [ 8.464172 ],
       [ 8.424855 ],
       [ 8.405197 ],
       [ 8.405197 ],
       [ 8.405197 ],
       [ 8.444513 ],
       [ 8.562464 ],
       [ 8.97529  ],
       [ 9.388117 ],
       [ 9.72231  ]], dtype=float32)
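
Why this works: dividing x by its max shrinks mean(x²) from the hundreds down to order 1, which shrinks λ_max of the loss Hessian by roughly the same factor, so lr=0.01 is back inside the stable range. Convergence is still slow, though: the loss is only down to about 27 after 40 epochs. A common next step, not in the original answer, is to standardize both variables to zero mean and unit variance, which usually converges in far fewer epochs. A minimal sketch of that variant, assuming the same data_x / data_y arrays:

import torch
import torch.nn as nn
import torch.optim as optim

# standardize both variables: zero mean, unit variance
x_mean, x_std = data_x.mean(), data_x.std()
y_mean, y_std = data_y.mean(), data_y.std()
X = torch.tensor((data_x - x_mean) / x_std, dtype=torch.float32).reshape(-1, 1)
y = torch.tensor((data_y - y_mean) / y_std, dtype=torch.float32).reshape(-1, 1)

model = nn.Linear(1, 1)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)  # safe: mean(x^2) is now ~1

for epoch in range(100):
    loss = loss_fn(model(X), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# undo the target scaling to read predictions back in degrees
predicted = model(X).detach().numpy() * y_std + y_mean

Standardizing the target as well keeps the loss near 1 at initialization and makes the one-dimensional problem well conditioned; the last line simply inverts the scaling.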