How do I fix "X has 1 features, but LinearRegression is expecting 3 features as input."?

Question:
from sklearn import linear_model
from sklearn.metrics import mean_squared_error
from math import sqrt
import matplotlib.pyplot as plt
import numpy as np
from sklearn.preprocessing import PolynomialFeatures
from sklearn.model_selection import train_test_split

def f(x):
 return np.sin(2*np.pi*x) + np.random.normal(scale=0.1, size=len(x))

NUM_SAMPLES=15
x = np.random.uniform(0,1,NUM_SAMPLES)
y = f(x)
X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.33)
plt.plot(x,y,'bo')
plt.show()
poly = PolynomialFeatures(degree=2)
X_train= poly.fit_transform(np.array(X_train).reshape(-1, 1))

regr = linear_model.LinearRegression()
regr.fit(X_train, y_train)
y_pred = regr.predict(np.array(X_test).reshape(-1, 1))

# The coefficients
print('Coefficients: \n', regr.coef_)
rmse=sqrt(mean_squared_error(y_test, y_pred))
# Root mean squared error
print('Root mean squared error: %.2f' % rmse)

I tried reshaping the X_train and X_test data, but I get this error:

X has 1 features, but LinearRegression is expecting 3 features as input.
Answer:

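I think the main problem is that you never apply the polynomial transform to X_test before passing it to the linear regression, even though the regression was fit on the transformed X_train. With degree=2 and a single input column, PolynomialFeatures produces three columns (a bias column of ones, x, and x²), so the fitted model expects 3 features while the raw X_test still has only 1. Fit the transform on X_train, then apply the same fitted transform to X_test with poly.transform. See the snippet below for a working version: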

from sklearn import linear_model
from sklearn.metrics import mean_squared_error
from math import sqrt
import matplotlib.pyplot as plt
import numpy as np
from sklearn.preprocessing import PolynomialFeatures
from sklearn.model_selection import train_test_split
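
def f(x):
    # Noise must match the shape of x; with size=len(x), a (15, 1) column x
    # would broadcast against (15,) noise into a (15, 15) array
    return np.sin(2*np.pi*x) + np.random.normal(scale=0.1, size=np.shape(x))

NUM_SAMPLES = 15
x = np.random.uniform(0, 1, (NUM_SAMPLES, 1))  # column vector: shape (n_samples, 1)
y = f(x).ravel()  # 1-D target: shape (n_samples,)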
X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.33)

plt.plot(x,y,'bo')
plt.show()
poly = PolynomialFeatures(degree=2)
X_train_transformed = poly.fit_transform(X_train)

#Now that you've fit it on X_train, use it to transform X_test:
X_test_transformed = poly.transform(X_test)

regr = linear_model.LinearRegression()
regr.fit(X_train_transformed, y_train)
y_pred = regr.predict(X_test_transformed)

# The coefficients
print('Coefficients: \n', regr.coef_)
rmse=sqrt(mean_squared_error(y_test, y_pred))
# Root mean squared error
print('Root mean squared error: %.2f' % rmse)
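One way to avoid transforming X_train and X_test by hand is to chain both steps with scikit-learn's `make_pipeline`. The pipeline fits PolynomialFeatures inside `fit` and re-applies that same fitted transform inside `predict`, so raw one-column inputs work in both places. This is a sketch under a few assumptions that differ from the answer above: it uses a seeded `np.random.default_rng` generator (seed values are arbitrary), sets `include_bias=False` because LinearRegression already fits its own intercept, and omits the plot.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures

NUM_SAMPLES = 15
rng = np.random.default_rng(42)                  # arbitrary seed for repeatable runs
x = rng.uniform(0, 1, size=(NUM_SAMPLES, 1))     # column vector: shape (15, 1)
y = np.sin(2 * np.pi * x).ravel() + rng.normal(scale=0.1, size=NUM_SAMPLES)

X_train, X_test, y_train, y_test = train_test_split(
    x, y, test_size=0.33, random_state=0)

# fit() runs PolynomialFeatures.fit_transform on X_train, then fits the regression;
# predict() runs PolynomialFeatures.transform on its input before predicting.
model = make_pipeline(
    PolynomialFeatures(degree=2, include_bias=False),  # columns: x, x^2
    LinearRegression(),                                # fits the intercept itself
)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)                   # raw (n, 1) input is fine here

rmse = np.sqrt(mean_squared_error(y_test, y_pred))
print('Coefficients:', model.named_steps['linearregression'].coef_)
print('Intercept:', model.named_steps['linearregression'].intercept_)
print('Root mean squared error: %.2f' % rmse)
```

Because `include_bias=False` drops the constant column, `coef_` holds only the weights for x and x², and the constant term appears in `intercept_` instead of as a redundant zero coefficient.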
© www.soinside.com 2019 - 2024. All rights reserved.