我是 Pytorch 和 Pytorch-Lightning 的新手,所以我从一些基本的东西开始检查我没有做错任何事。
首先我编了一些数据:
from sklearn.datasets import make_regression
x, y = make_regression(n_samples=100, n_features=4, n_informative=2, bias=False)
然后我尝试使用 Pytorch 和 Pytorch-Lightning 做一些简单的线性回归
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from sklearn.metrics import mean_squared_error
class LinearRegression(pl.LightningModule):
def __init__(self):
super().__init__()
self.lin = nn.Linear(4, 1, bias=False)
def forward(self, x):
return self.lin(x)
def training_step(self, batch, batch_idx):
x, y = batch
x_hat = self(x)
loss = F.mse_loss(x_hat, y)
return loss
def predict_step(self, batch, batch_idx, dataloader_idx=0):
x, y = batch
return self(x)
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.parameters(), lr=1e-5)
return optimizer
dataset = TensorDataset(torch.Tensor(x), torch.Tensor(y))
dataloader = DataLoader(dataset, batch_size=32)
model = LinearRegression()
trainer = pl.Trainer(max_epochs=300)
trainer.fit(model=model, train_dataloaders=dataloader)
preds = trainer.predict(model, dataloader)
print(f"MSE: {mean_squared_error(torch.concat(preds), y)}")
对于这样一个简单的问题,这给了我一个疯狂的高错误。我试着玩弄时代的数量或学习率,但什么也做不了。
用 scikit-learn 线性回归做同样的事情给我 0 个错误。
from sklearn.linear_model import LinearRegression
linear = LinearRegression(fit_intercept=False)
linear.fit(x, y)
print(f"MSE: {mean_squared_error(linear.predict(x), y)}")
我猜这不是模型,而是我在使用 pytorch 或/和 Pytorch-lightning 时错过的东西。