我希望我的神经网络能够解决多项式回归问题,例如 y=(x*x) + 2x -3。
所以现在我创建了一个具有 1 个输入节点、100 个隐藏节点和 1 个输出节点的网络,并给了它很多时期来使用高测试数据量进行训练。问题是,大约 20000 个 epoch 后的预测还可以,但比训练后的线性回归预测要差得多。
import torch
from torch import Tensor
from torch.nn import Linear, MSELoss, functional as F
from torch.optim import SGD, Adam, RMSprop
from torch.autograd import Variable
import numpy as np
# define our data generation function
def data_generator(data_size=1000):
# f(x) = y = x^2 + 4x - 3
inputs = []
labels = []
# loop data_size times to generate the data
for ix in range(data_size):
# generate a random number between 0 and 1000
x = np.random.randint(1000) / 1000
# calculate the y value using the function x^2 + 4x - 3
y = (x * x) + (4 * x) - 3
# append the values to our input and labels lists
inputs.append([x])
labels.append([y])
return inputs, labels
# define the model
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = Linear(1, 100)
self.fc2 = Linear(100, 1)
def forward(self, x):
x = F.relu(self.fc1(x)
x = self.fc2(x)
return x
model = Net()
# define the loss function
critereon = MSELoss()
# define the optimizer
optimizer = SGD(model.parameters(), lr=0.01)
# define the number of epochs and the data set size
nb_epochs = 20000
data_size = 1000
# create our training loop
for epoch in range(nb_epochs):
X, y = data_generator(data_size)
X = Variable(Tensor(X))
y = Variable(Tensor(y))
epoch_loss = 0;
y_pred = model(X)
loss = critereon(y_pred, y)
epoch_loss = loss.data
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("Epoch: {} Loss: {}".format(epoch, epoch_loss))
# test the model
model.eval()
test_data = data_generator(1)
prediction = model(Variable(Tensor(test_data[0][0])))
print("Prediction: {}".format(prediction.data[0]))
print("Expected: {}".format(test_data[1][0]))
他们是获得更好结果的方法吗?我想知道是否应该尝试获得 3 个输出,将它们称为 a、b 和 c,这样 y= a(x*x)+b(x)+c。但我不知道如何实现它并训练我的神经网络。
对于这个问题,如果您将具有 1 个
Net()
层的 Linear
视为具有包括 Linear Regression
在内的输入特征的 [x^2, x]
,可能会更容易。
import torch
from torch import Tensor
from torch.nn import Linear, MSELoss, functional as F
from torch.optim import SGD, Adam, RMSprop
from torch.autograd import Variable
import numpy as np
# define our data generation function
def data_generator(data_size=1000):
# f(x) = y = x^2 + 4x - 3
inputs = []
labels = []
# loop data_size times to generate the data
for ix in range(data_size):
# generate a random number between 0 and 1000
x = np.random.randint(2000) / 1000 # I edited here for you
# calculate the y value using the function x^2 + 4x - 3
y = (x * x) + (4 * x) - 3
# append the values to our input and labels lists
inputs.append([x*x, x])
labels.append([y])
return inputs, labels
# define the model
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = Linear(2, 1)
def forward(self, x):
return self.fc1(x)
model = Net()
Epoch: 0 Loss: 33.75775909423828
Epoch: 1000 Loss: 0.00046704441774636507
Epoch: 2000 Loss: 9.437128483114066e-07
Epoch: 3000 Loss: 2.0870876138445738e-09
Epoch: 4000 Loss: 1.126847400112485e-11
Prediction: 5.355223655700684
Expected: [5.355224999999999]
你要找的系数
a
、b
、c
实际上就是self.fc1
的权重和偏差:
print('a & b:', model.fc1.weight)
print('c:', model.fc1.bias)
# Output
a & b: Parameter containing:
tensor([[1.0000, 4.0000]], requires_grad=True)
c: Parameter containing:
tensor([-3.0000], requires_grad=True)
仅在 5000 个 epoch 内,全部收敛:
a
-> 1、b
-> 4 和 c
-> -3。
该模型非常轻量级,只有 3 个参数,而不是:
(100 + 1) + (100 + 1) = 202 parameters in the old model
希望这对您有帮助!
如果您知道近似系数,那么梯度下降是一种非常低效的拟合多项式的方法,正如已接受的答案中所建议的那样。 您可以使用 torch.linalg.lstsq 或 numpy 等效项 numpy.linalg.lstsq 获得直接解析解。
import torch
import matplotlib.pyplot as plt
## We won't need any gradients here
torch.set_grad_enabled(False)
# Generate some test data
x = torch.arange(-1, 2, 0.01)
# Define our test function
y = (x * x) + (0.5 * x) - 1.1
# Add some noise to the data
y += torch.randn_like(x) * 0.5
# Create some features
X = torch.stack([
x,
x**2,
torch.ones_like(x) # bias
]).T
# Fit to data
W = torch.linalg.lstsq(X, y.unsqueeze(-1)).solution
pred = (X @ W)
plt.scatter(x, y)
plt.plot(x, pred, c='orange', linewidth=5, alpha=0.8)
plt.show()