为什么我的最小二乘拟合结果与数据点不吻合

问题描述 投票:0回答:1

我用下面的代码拟合得到的二维平面与数据点吻合得不太好:

import numpy as np
from scipy import linalg as la
from scipy.linalg import solve

# data
# Two feature vectors with four observations each.
f1 = np.array([1., 1.5, 3.5, 4.])
f2 =  np.array([3., 4., 7., 7.25])
# z = np.array([6., 6.5, 8., 9.])
# Stack the features as columns: A (alias X) is a 4x2 design matrix.
# There is no column of ones, so the fitted model has no intercept term.
A= X= np.array([f1, f2]).T

# Target values; .T leaves a 1-D array unchanged.
b= y= np.array([0.5, 1., 1.5, 2.]).T

##################### la.lstsq

# la.lstsq returns a tuple; element 0 is the least-squares solution,
# here one coefficient per column of A (i.e. per feature).
res= la.lstsq(A,b)[0]
print(res)    
##################### custom lu

#custom OLS 
def ord_ls(X, y):
    """Solve ordinary least squares through the normal equations.

    Forms the Gram matrix ``X.T @ X`` and the moment vector ``X.T @ y``
    and passes them to ``scipy.linalg.solve``.

    Parameters
    ----------
    X : array
        Design matrix, one row per observation.
    y : array
        Target values, one per observation.

    Returns
    -------
    array
        Coefficients ``beta`` with ``(X.T @ X) @ beta == X.T @ y``.
    """
    design_t = X.T
    gram = design_t @ X
    moment = design_t @ y
    # Both operands are temporaries built above, so allowing the solver to
    # overwrite them cannot modify the caller's X or y.
    return solve(gram, moment, overwrite_a=True, overwrite_b=True,
                 check_finite=True)

# Refit with the normal-equation solver; res is rebound to this result,
# which is what the plotting code below reads.
res = ord_ls(X, y)
print(res)

##################### plot

# use the optimized parameters to plot the fitted curve in 3D space.
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Create 3D plot of the data points and the fitted curve
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# Observations: feature 1 on x, feature 2 on y, target on z.
ax.scatter(f1, f2, y, color='blue')
x_range = np.linspace(0, 7, 100)
y_range = np.linspace(0, 7,100)

# Rebinds X (previously the design matrix) to the grid's x-coordinates.
X, Y = np.meshgrid(x_range, y_range)
# NOTE(review): res holds one coefficient per feature (design matrix is
# [f1, f2] with no intercept column), yet this surface uses res[1] as a
# constant offset and never uses Y. The surface matching the fitted model
# would be res[0]*X + res[1]*Y.
Z = res[0]*X + res[1]

ax.plot_surface(X, Y, Z, color='red', alpha=0.5)
ax.set_xlabel('feat.1')
ax.set_ylabel('feat.2')
ax.set_zlabel('target')
plt.show()

# [0.2961165  0.09475728]
# [0.2961165  0.09475728]

尽管两种方法得到的系数看起来相同,但画出来的图仍然是歪的。有什么解释或者修正方法吗?是否需要某种正则化,比如最小二乘法?还是因为两个特征共线,才导致了这个结果?(我对线性代数还不是很熟悉)

附注scipy-0.18.0 文档

python scipy least-squares matrix-factorization
1个回答
0
投票

重命名了主题,感谢贾里德,

使用下面的代码似乎就可以了(转自我的评论)。关键的修正是给设计矩阵加上一列全 1,使模型包含截距项(见代码中标注 CORRECTION 的那一行):

import numpy as np
from scipy import linalg as la
from scipy.linalg import solve

# data
# Two feature vectors with four observations each.
f1 = np.array([1., 1.5, 3.5, 4.])
f2 =  np.array([3., 4., 7., 7.25])
# z = np.array([6., 6.5, 8., 9.])
# Stack the features as columns: A (alias X) is a 4x2 matrix.
A= X= np.array([f1, f2]).T
# Target values; .T leaves a 1-D array unchanged.
b= y= np.array([0.5, 1., 1.5, 2.]).T

# Prepend a column of ones so the model gains an intercept term.
# X is rebound to the new 4x3 matrix [1, f1, f2]; A still names the
# original 4x2 matrix.
X = np.column_stack((np.ones(len(y)),X))    # CORRECTION

##################### la.lstsq

# la.lstsq returns a tuple; element 0 is the least-squares solution,
# here ordered as [intercept, coef_f1, coef_f2] to match X's columns.
res= la.lstsq(X,b)[0]
print(res)

#####################
##from scipy.linalg import lu_factor, lu_solve
##
##lu, piv = lu_factor(A)  # ValueError: expected square matrix
##x = lu_solve((lu, piv), b)
##print(x)

##################### custom ls

#custom OLS
def ord_ls(X, y):
    """Return the OLS coefficients for ``y ~ X`` via the normal equations.

    The system ``(X.T @ X) @ beta = X.T @ y`` is assembled and solved with
    ``scipy.linalg.solve``.

    Parameters
    ----------
    X : array
        Design matrix, one row per observation.
    y : array
        Target values, one per observation.

    Returns
    -------
    array
        The solution ``beta`` of the normal equations.
    """
    normal_matrix, normal_rhs = X.T @ X, X.T @ y
    # The two operands are fresh temporaries, so the overwrite flags only
    # let the solver reuse their memory; the caller's arrays are untouched.
    coefficients = solve(
        normal_matrix,
        normal_rhs,
        overwrite_a=True,
        overwrite_b=True,
        check_finite=True,
    )
    return coefficients

# Refit with the normal-equation solver on the intercept-augmented X;
# res is rebound to this result, which the plotting code below reads.
res = ord_ls(X, y)
print(res)

##################### plot

# use the optimized parameters to plot the fitted curve in 3D space.
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Create 3D plot of the data points and the fitted plane
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# Observations: feature 1 on x, feature 2 on y, target on z.
ax.scatter(f1, f2, y, color='blue')
x_range = np.linspace(0, 7, 100)
y_range = np.linspace(0, 7,100)

# Rebinds X (previously the design matrix) to the grid's x-coordinates.
X, Y = np.meshgrid(x_range, y_range)
# res was fitted on the design matrix [1, f1, f2], so it is ordered as
# [intercept, coef_f1, coef_f2]. The plane must use all three terms;
# the earlier form res[0]*X + res[1] scaled the intercept by feature 1
# and dropped feature 2 entirely.
Z = res[0] + res[1]*X + res[2]*Y

ax.plot_surface(X, Y, Z, color='red', alpha=0.5)
ax.set_xlabel('feat.1')
ax.set_ylabel('feat.2')
ax.set_zlabel('target')
plt.show()
© www.soinside.com 2019 - 2024. All rights reserved.