Python中线性代数的线性回归

问题描述 投票:0回答:1

我在维基百科(https://en.wikipedia.org/wiki/Coefficient_of_determination)上解释这些公式错误吗?以下是我的尝试。

ssres

def ss_res(X, y, theta):

    y_diff=[]
    y_pred = X.dot(theta)

    for i in range(0, len(y)):
        y_diff.append((y[i]-y_pred[i])**2)

    return np.sum(y_diff)

输出看起来正确,但数字稍微偏离......就像几个小数点。

stderror

def std_error(X, y, theta):


    delta = (1/(len(y)-X.shape[1]+1))*(ss_res(X,y,theta))
    matrix1=matrix_power((X.T.dot(X)),-1)
    thing2=delta*matrix1
    thing3=scipy.linalg.sqrtm(thing2)

    res=np.diag(thing3)
    serr=np.reshape(res, (6, 1))
    return serr

std_error_array=std_error(X,y,theta)

python linear-regression linear-algebra
1个回答
2
投票

您可能会或可能不想要+1所谓的delta,取决于您的X是否包含“常量”列(即所有值= 1)

如果有点非Pythonic,它看起来还不错。我很想把它们写成:

import numpy as np
from numpy.linalg import inv
from scipy.linalg import sqrtm

def solve_theta(X, Y):
    return np.linalg.solve(X.T @ X, X.T @ Y)

def ss_res(X, Y, theta):
    res = Y - (X @ theta)
    return np.sum(res ** 2)

def std_error(X, Y, theta):
    nr, rank = X.shape
    resid_df = nr - rank
    residvar = ss_res(X, Y, theta) / resid_df
    var_theta = residvar * inv(X.T @ X)
    return np.diag(sqrtm(var_theta))[:,None]

注意:这使用Python 3.5 style matrix multiply operator @而不是写出.dot()

这种算法的数值稳定性并不令人惊讶,您可能需要考虑使用SVD或QR分解。有一个平易近人的描述如何使用SVD进行:

John Mandel(1982)“在回归分析中使用奇异值分解”10.1080/00031305.1982.10482771

我们可以通过创建一些虚拟数据来测试它:

np.random.seed(42)

N = 20
K = 3

true_theta = np.random.randn(K, 1) * 5
X = np.random.randn(N, K)
Y = np.random.randn(N, 1) + X @ true_theta

并在其上运行上面的代码:

theta = solve_theta(X, Y)
sse = std_error(X, Y, theta)

print(np.column_stack((theta, sse)))

这使:

[[ 2.23556391  0.35678574]
 [-0.40643163  0.24751913]
 [ 3.14687637  0.26461827]]

我们可以用statsmodels来测试这个:

import statsmodels.api as sm

sm.OLS(Y, X).fit().summary()

这使:

                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
x1             2.2356      0.358      6.243      0.000       1.480       2.991
x2            -0.4064      0.248     -1.641      0.119      -0.929       0.116
x3             3.1469      0.266     11.812      0.000       2.585       3.709

这非常接近。

© www.soinside.com 2019 - 2024. All rights reserved.