scikit学习和statsmodels - 其中R平方是正确的?

问题描述 投票:1回答:3

我想选择对未来最好的算法。我找到了一些解决方案,但我不明白其中的R平方值是正确的。

对于这一点,我分我的数据分成两个测试和训练,我印了不同的R平方值以下。

import statsmodels.api as sm
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

lineer = LinearRegression()
lineer.fit(x_train,y_train)
lineerPredict = lineer.predict(x_test)

scoreLineer = r2_score(y_test, lineerPredict)  # First R-Squared

model = sm.OLS(lineerPredict, y_test)
print(model.fit().summary()) # Second R-Squared

第一R平方结果为-4.28。 第二R平方的结果是0.84

但我不明白它的值是正确的。

python machine-learning scikit-learn linear-regression statsmodels
3个回答
7
投票

可以说,在这种情况下,真正的挑战是要确保你比较苹果和苹果。而在你的情况,似乎你不知道。我们最好的朋友是永远的相关文件,用简单的experinets结合。所以...

虽然scikit学习的LinearRegression()(即你的第一个R平方)默认与fit_intercept=Truedocs)装,这不符合statsmodels' OLS的情况下(你的第二R平方);从docs引用:

截距不包括默认和应该由用户添加。见statsmodels.tools.add_constant

牢记这个重要的细节,让我们来运行虚拟数据一些简单的实验:

import numpy as np
import statsmodels.api as sm
from sklearn.metrics import r2_score
from sklearn.linear_model import LinearRegression

# dummy data:
y = np.array([1,3,4,5,2,3,4])
X = np.array(range(1,8)).reshape(-1,1) # reshape to column

# scikit-learn:
lr = LinearRegression()
lr.fit(X,y)
# LinearRegression(copy_X=True, fit_intercept=True, n_jobs=None,
#     normalize=False)

lr.score(X,y)
# 0.16118421052631582

y_pred=lr.predict(X)
r2_score(y, y_pred)
# 0.16118421052631582


# statsmodels
# first artificially add intercept to X, as advised in the docs:
X_ = sm.add_constant(X)

model = sm.OLS(y,X_) # X_ here
results = model.fit()
results.rsquared
# 0.16118421052631593

对于所有实际目的,这两个值R平方所产生scikit学习和statsmodels是相同的。

让我们更进一步,并尝试scikit学习模型,而不会拦截,但在这里我们使用的人为“截获”的数据X_我们已经建立了与statsmodels使用:

lr2 = LinearRegression(fit_intercept=False)
lr2.fit(X_,y) # X_ here
# LinearRegression(copy_X=True, fit_intercept=False, n_jobs=None,
#         normalize=False)

lr2.score(X_, y)
# 0.16118421052631593

y_pred2 = lr2.predict(X_)
r2_score(y, y_pred2)
# 0.16118421052631593

同样,R平方是与以前的值相同。

所以,当我们“不小心”忘记考虑到该statsmodels OLS安装不拦截的事实会发生什么?让我们来看看:

model3 = sm.OLS(y,X) # X here, i.e. no intercept
results3 = model2.fit()
results3.rsquared
# 0.8058035714285714

那么,一个R平方0.80的确是从0.16由模型截取返回的一个非常远,可以说这正是你的情况发生。

到目前为止好,我可以很容易地在这里完成了答案;但确实是有这个地方和谐世界打破了一个观点:让我们来看看,当我们适应这两种模式,而不拦截,并在这里我们不能人为添加任何拦截的初始数据X会发生什么。我们已经安装了OLS模型之上,并得到了一个R平方为0.80;怎么样从scikit学习了类似的模式?

# scikit-learn
lr3 = LinearRegression(fit_intercept=False)
lr3.fit(X,y) # X here
lr3.score(X,y)
# -0.4309210526315792

y_pred3 = lr3.predict(X)
r2_score(y, y_pred3)
# -0.4309210526315792

噢...!有没有搞错??

似乎scikit-赚,当计算r2_score,总是假定截距,明确地在模型(fit_intercept=True)或隐式的数据(我们已经产生X_X以上,使用statsmodels' add_constant的样子);挖一点点网上显示在那里证实,情况确实像这样的Github thread(没有补救关闭)。

让我澄清一下,我上面已经描述无关,与你的问题的差异:在你的情况下,真正的问题是,你实际上是在比较苹果(截距模型)橘子(无截距模型)。


那么,为什么scikit学习,不仅在这样的(诚然边缘)的情况下会失败,但即使这样的事实在Github的问题出现它实际上是用冷漠对待? (另请注意,scikit学习核心开发谁在上线回复随便承认:“我不是超级熟悉统计” ......)。

这个答案有点出乎编码问题,如那些SO主要是关于,但它可能是值得阐述了一点在这里。

可以说,其原因是整个R平方概念来自实际上直接从世界的统计数据,其中重点是解释模型,以及它在机器学习环境,其中的重点显然是在预测模型很少使用;至少据我所知,超越一些非常入门课程,我从来没有(我的意思是从来没有...)见过预测建模问题,其中的R平方被用于任何类型的绩效考核;既不是偶然的,流行的机器学习的介绍,如安德鲁·Ng的在Coursera Machine Learning,甚至懒得提它。而且,正如在Github上线以上(强调)指出:

特别是使用测试集的时候,这是一个有点不清楚我是什么R ^ 2种手段。

与我当然同意。

正如上面所讨论的边缘情况(包括或不截距项?),我怀疑它的声音听起来真的无关紧要现代深度学习的从业者,其中截距(偏置参数)相当于是默认总是在神经网络模型包括...

看到在十字验证问题Difference between statsmodel OLS and scikit linear regression接受(和高度upvoted)答案沿着这些最后几行更详细的讨论......


3
投票

你似乎可以用sklearn.metrics_r2_score。该文件指出,“最好的得分是1.0,它可以是负的(因为该模型可以任意更糟)”

该文档导致指出“当模型拟合比水平超平面越差数据可能会发生范围为0以外R2的值,以1 Wikipedia article。这将通过施加无意义的约束时选择了错误的模型发生,或错误”。出于这个原因,你有过这样的负面r2_score事实可能比你有一个比较好的(但不是很大)R ^ 2统计中的其他方式计算更为显著。如果第一次得分表明你的模型的选择是差那么第二个统计很可能是刚刚过拟合的神器。

最终,这更是一个方法论的问题,一个编程问题。你可能会想张贴关于如何解释其中R ^ 2的两个版本不同,因此疯狂的模型Cross Validated一个后续问题。如果您确实张贴了这样的问题,请确保你给你的造型有什么多一点信息。


2
投票

当你注意,并作为the Wikipedia article指出,也有“R平方”或多重定义“R平方”。然而,常见的都有,他们的范围从01财产。他们通常是积极的,因为是从名字的“平方”部分明确。 (有关例外一般规则,请参见维基百科的文章。)

你的“第一R平方结果”是-4.28,这是不01之间,甚至不积极。因此,它是不是一个真正的“R平方”的说法。因此,用“二R平方结果”,这是在正确的范围内。

你不说你正在使用的库,所以我不能说什么你所谓的“第一R平方结果”实际上是。从现在起,当你在这里提出一个问题,请出示完整的代码片段,我们可以复制和粘贴和运行 - 要记住,所有import语句。

© www.soinside.com 2019 - 2024. All rights reserved.