现在可与 PyTorch 一起使用的简单线性回归

问题描述 投票:0回答:0

我是 Pytorch 和 Pytorch-Lightning 的新手,所以我从一些基本的东西开始检查我没有做错任何事。

首先我编了一些数据:

from sklearn.datasets import make_regression
x, y = make_regression(n_samples=100, n_features=4, n_informative=2, bias=False)

然后我尝试使用 Pytorch 和 Pytorch-Lightning 做一些简单的线性回归

import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from sklearn.metrics import mean_squared_error

class LinearRegression(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 1, bias=False)

    def forward(self, x):
        return self.lin(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        x_hat = self(x)
        loss = F.mse_loss(x_hat, y)
        return loss

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        x, y = batch
        return self(x)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=1e-5)
        return optimizer

dataset = TensorDataset(torch.Tensor(x), torch.Tensor(y))
dataloader = DataLoader(dataset, batch_size=32)

model = LinearRegression()
trainer = pl.Trainer(max_epochs=300)
trainer.fit(model=model, train_dataloaders=dataloader)
preds = trainer.predict(model, dataloader)
print(f"MSE: {mean_squared_error(torch.concat(preds), y)}")

对于这样一个简单的问题,这给了我一个疯狂的高错误。我试着玩弄时代的数量或学习率,但什么也做不了。

用 scikit-learn 线性回归做同样的事情给我 0 个错误。

from sklearn.linear_model import LinearRegression
linear = LinearRegression(fit_intercept=False)
linear.fit(x, y)
print(f"MSE: {mean_squared_error(linear.predict(x), y)}")

我猜这不是模型,而是我在使用 pytorch 或/和 Pytorch-lightning 时错过的东西。

pytorch linear-regression pytorch-lightning
© www.soinside.com 2019 - 2024. All rights reserved.