用Statsmodels进行简单的逻辑回归。添加截距并使逻辑回归方程可视化

问题描述 投票:1回答:1

使用Statsmodels,我试图生成一个简单的逻辑回归模型,根据一个人的身高(Hgt)来预测他是否吸烟(Smoke)。

我觉得需要在逻辑回归模型中加入截距,但我不知道如何使用add_constant()函数来实现。另外,我也不知道为什么会产生下面的错误。

这是数据集,Pulse.CSV。https:/drive.google.comfiled1FdUK9p4Dub4NXsc-zHrYI-AGEEBkX98Vview?usp=sharing。

完整的代码和输出都在这个PDF文件中。https:/drive.google.comfiled1kHlrAjiU7QvFXF2a7tlTSFPgfpq9bOXJview?usp=sharing。

import numpy as np
import pandas as pd
import statsmodels.api as sm
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
raw_data = pd.read_csv('Pulse.csv')
raw_data
x1 = raw_data['Hgt']
y = raw_data['Smoke'] 
reg_log = sm.Logit(y,x1,missing='Drop')
results_log = reg_log.fit()
def f(x,b0,b1):
    return np.array(np.exp(b0+x*b1) / (1 + np.exp(b0+x*b1)))
f_sorted = np.sort(f(x1,results_log.params[0],results_log.params[1]))
x_sorted = np.sort(np.array(x1))
plt.scatter(x1,y,color='C0')
plt.xlabel('Hgt', fontsize = 20)
plt.ylabel('Smoked', fontsize = 20)
plt.plot(x_sorted,f_sorted,color='C8')
plt.show()

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
~/opt/anaconda3/lib/python3.7/site-packages/pandas/core/indexes/base.py in get_value(self, series, key)
   4729         try:
-> 4730             return self._engine.get_value(s, k, tz=getattr(series.dtype, "tz", None))
   4731         except KeyError as e1:
((( Truncated for brevity )))
IndexError: index out of bounds
data-visualization linear-regression constants statsmodels traceback
1个回答
1
投票

截距 中默认不添加 Statsmodels 回归,但如果你需要,你可以手动包含它。

import numpy as np
import pandas as pd
import statsmodels.api as sm
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
raw_data = pd.read_csv('Pulse.csv')
raw_data
x1 = raw_data['Hgt']
y = raw_data['Smoke'] 

x1 = sm.add_constant(x1)

reg_log = sm.Logit(y,x1,missing='Drop')
results_log = reg_log.fit()

results_log.summary()

def f(x,b0,b1):
    return np.array(np.exp(b0+x*b1) / (1 + np.exp(b0+x*b1)))
f_sorted = np.sort(f(x1,results_log.params[0],results_log.params[1]))
x_sorted = np.sort(np.array(x1))

plt.scatter(x1['Hgt'],y,color='C0')

plt.xlabel('Hgt', fontsize = 20)
plt.ylabel('Smoked', fontsize = 20)
plt.plot(x_sorted,f_sorted,color='C8')
plt.show()

这也将解决这个错误,因为在你的初始代码中没有截获。源代码

© www.soinside.com 2019 - 2024. All rights reserved.