Trying to use a Gaussian GLM to trivially match OLS, but it is either very wrong or hits perfect separation

I'm trying to get GLM to do what OLS does, just to build a basic understanding of GLM, but it doesn't seem to do what I want. Consider this code:

import numpy as np
import statsmodels.api as sm
from scipy import stats

print("### This is dummy data that clearly is y = 0.25 * x + 5: #############")
aInput = np.arange(10)
print(aInput)
aLinear = aInput.copy() * 0.25 + 5
print(aLinear)

print("### This is OLS to show clearly what we're after: ####################")
aInputConst = sm.add_constant(aInput)
model = sm.OLS(aLinear, aInputConst)
results = model.fit()
print(results.params)

print("This is GLM which looks nothing like what I expect: ##################")
model = sm.GLM(aLinear, aInput, family=sm.families.Gaussian())
result = model.fit()
y_hat = result.predict(aInput)
print(y_hat)

print("This is GLM with the constant, but it just fails: ####################")
#                       vvvvvvvvvvv
model = sm.GLM(aLinear, aInputConst, family=sm.families.Gaussian())
result = model.fit()
y_hat = result.predict(aInput)
print(y_hat)

Now consider the output:

### This is dummy data that clearly is y = 0.25 * x + 5: #############
[0 1 2 3 4 5 6 7 8 9]
[5.   5.25 5.5  5.75 6.   6.25 6.5  6.75 7.   7.25]
### This is OLS to show clearly what we're after: ####################
[5.   0.25]
This is GLM which looks nothing like what I expect: ##################
[0.         1.03947368 2.07894737 3.11842105 4.15789474 5.19736842
 6.23684211 7.27631579 8.31578947 9.35526316]
This is GLM with the constant, but it just fails: ####################
Traceback (most recent call last):
  File "min.py", line 26, in <module>
    result = model.fit()
  File "/export/home/jm43436e/.local/lib/python3.6/site-packages/statsmodels/genmod/generalized_linear_model.py", line 1065, in fit
    cov_kwds=cov_kwds, use_t=use_t, **kwargs)
  File "/export/home/jm43436e/.local/lib/python3.6/site-packages/statsmodels/genmod/generalized_linear_model.py", line 1211, in _fit_irls
    raise PerfectSeparationError(msg)
statsmodels.tools.sm_exceptions.PerfectSeparationError: Perfect separation detected, results not available

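Observations:

  • The data is constructed as y = 0.25 x + 5, so that we know the "right" answer when we see it.
  • OLS finds this naturally: it finds 5 and it finds 0.25. Simple enough.
  • The first GLM attempt, however, runs from about 0 to about 9.4, whereas I expected it to run from 5 to 7.25, like the original data. Why doesn't it? One thing I've read is that you need to add the constant, which leads to the next problem:
  • When I add the constant, fitting fails with a "perfect separation" error. If I am supposed to add that constant, how can I add it without getting the error?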

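My question:

  • How can I get GLM to behave like OLS and tell me 5 and 0.25? I just want this as a starting baseline, but I can't get there. Either the first GLM call should give me b1 and b0 along with the correct output range, or, if I am supposed to add the constant, the second call should give me b1 and b0.

I'm clearly confused. Any help is appreciated!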

python statsmodels glm
1 Answer

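Got it! Two things:

  • Yes, you do need the sm.add_constant() call. As many of the sites I read point out, statsmodels does not add an intercept for you when you pass arrays: if you want an intercept value, the design matrix needs a constant input column. (scikit-learn differs here: its LinearRegression fits an intercept by default, via fit_intercept=True.)

  • The problem is that the dummy data was too good. After adding some noise so that the points jitter around the regression line, it works fine. Apparently the fitting machinery has trouble when the data is perfect. I don't know why; this is just my empirical observation.

The working code, with comments marking the changes: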
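To make the first point concrete, here is a minimal sketch using plain NumPy rather than statsmodels. It relies on the fact that a Gaussian GLM with the identity link is an ordinary least-squares fit, so np.linalg.lstsq can stand in for it. The variable names (x, y, slope_no_const, X, params) are my own and not from the question.

```python
import numpy as np

x = np.arange(10, dtype=float)
y = 0.25 * x + 5  # the noise-free data from the question

# No constant column: the model is y ~ b * x, a line forced through the origin.
# Its least-squares slope has the closed form sum(x*y) / sum(x*x).
slope_no_const = (x @ y) / (x @ x)
print(slope_no_const)      # about 1.0395
print(slope_no_const * x)  # runs from 0 to about 9.36

# With a leading column of ones (what sm.add_constant prepends),
# the intercept is free to move away from zero.
X = np.column_stack([np.ones_like(x), x])
params, *_ = np.linalg.lstsq(X, y, rcond=None)
print(params)              # approximately [5, 0.25]
```

The through-origin predictions reproduce the 0 to 9.36 range from the question's first GLM call, which suggests that call was not wrong so much as fitting a different model.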

import numpy as np
import statsmodels.api as sm
from scipy import stats

print("### This is dummy data that clearly is y = 0.25 * x + 5: #############")
aInput = np.arange(10)
print(aInput)
noise = (np.random.rand(10) - 0.5) / 10 + 1 # New: random factors in [0.95, 1.05).
aLinear = aInput.copy() * 0.25 * noise + 5 # Changed: scale the slope term by the noise.
print(aLinear)

print("### This is OLS to show clearly what we're after: ####################")
aInputConst = sm.add_constant(aInput)
model = sm.OLS(aLinear, aInputConst)
results = model.fit()
print(results.params)

#   print("This is GLM which looks nothing like what I expect: ##################")
#   model = sm.GLM(aLinear, aInput, family=sm.families.Gaussian())
#   result = model.fit()
#   y_hat = result.predict(aInput)
#   print(y_hat)

print("This is GLM with the constant, but it just fails: ####################")
#                       vvvvvvvvvvv
model = sm.GLM(aLinear, aInputConst, family=sm.families.Gaussian())
result = model.fit()
y_hat = result.predict(aInputConst) # Changed: use the "Const" version.
print(y_hat)

The output is:

### This is dummy data that clearly is y = 0.25 * x + 5: #############
[0 1 2 3 4 5 6 7 8 9]
[5.         5.25190469 5.48057897 5.74306435 6.00161857 6.20957869
 6.43812791 6.71636422 7.09118653 7.25882849]
### This is OLS to show clearly what we're after: ####################
[4.98249325 0.25258489]
This is GLM with the constant, but it just fails: ####################
[4.98249325 5.23507814 5.48766302 5.74024791 5.9928328  6.24541768
 6.49800257 6.75058746 7.00317235 7.25575723]

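As mentioned, one thing I was after was recovering the coefficients I had built into the data, and sure enough they come out close to 5 and 0.25 (the GLM predictions above start at 4.98249325 and rise by 0.25258489 per step, so GLM and OLS agree):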

[4.98249325 0.25258489]

So yes, you need the constant column. And if your test data is too clean, jitter it a little so that it looks more like messy real-world data.
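A possible explanation for why clean data trips the error, offered as an assumption rather than something confirmed here: the traceback shows the exception is raised inside the IRLS fitting loop, and my understanding is that the loop bails out when the fitted values reproduce the response exactly, which is what happens with noise-free data and a constant column. The sketch below imitates such a check with NumPy; looks_like_perfect_fit is a hypothetical helper of mine, not a statsmodels function, and the additive Gaussian jitter is my own choice rather than the multiplicative noise used above.

```python
import numpy as np

def looks_like_perfect_fit(X, y):
    """Hypothetical helper: True when least squares reproduces y exactly."""
    params, *_ = np.linalg.lstsq(X, y, rcond=None)
    fitted = X @ params
    return bool(np.allclose(fitted - y, 0))

x = np.arange(10, dtype=float)
X_const = np.column_stack([np.ones_like(x), x])  # with constant column
X_plain = x[:, None]                             # slope only, no constant

clean = 0.25 * x + 5
rng = np.random.default_rng(0)
noisy = clean + rng.normal(scale=0.05, size=x.size)  # additive jitter

print(looks_like_perfect_fit(X_const, clean))  # True: residuals are all ~0
print(looks_like_perfect_fit(X_const, noisy))  # False: jitter leaves residuals
print(looks_like_perfect_fit(X_plain, clean))  # False: no intercept, poor fit
```

If that reading is right, it would also explain why the first GLM call in the question ran without an error: without the constant the fit is far from perfect, so the check never fires.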
