I read the related discussion here: Using statsmodel estimations with scikit-learn cross validation, is it possible?
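The linked discussion suggests wrapping `statsmodels` models in a wrapper class so that they can be used with the `cross_val_score` function from `sklearn`. The code does run, but I'm not sure which arguments to provide.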
Example code:
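```python
import statsmodels.api as sm
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.model_selection import cross_val_score

# sklearn recommends listing the mixin before BaseEstimator
class SMWrapper(RegressorMixin, BaseEstimator):
    """ A universal sklearn-style wrapper for statsmodels regressors """
    def __init__(self, model_class, fit_intercept=True):
        self.model_class = model_class
        self.fit_intercept = fit_intercept

    def fit(self, X, y):
        if self.fit_intercept:
            # has_constant="add": always add the intercept column, even if a
            # fold happens to contain a constant column
            X = sm.add_constant(X, has_constant="add")
        self.model_ = self.model_class(y, X)
        self.results_ = self.model_.fit()
        return self

    def predict(self, X):
        if self.fit_intercept:
            X = sm.add_constant(X, has_constant="add")
        return self.results_.predict(X)

# X: 2-D feature matrix (no constant column), y: 1-D response
cross_val_score(SMWrapper(sm.GLM), X, y, scoring='r2')
```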
I have a few questions about using `cross_val_score` with `statsmodels` models:
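1. I would like to use the formula API (`from statsmodels.formula.api import glm`) instead of `sm.GLM`. How do I then supply the data/variables/design matrix? Through the `X` and `y` arguments? With the formula approach I can just pass a complete dataframe `df` and specify the relevant variables using patsy syntax.
2. How do I pass model arguments to `sm.GLM`? I need to specify the distribution and link function, but I think `SMWrapper` only accepts the generic model class. Would passing `kwargs=dict(family=sm.families.Gaussian())` as an argument to `SMWrapper` work?

Full code example: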
```python
import random

import numpy as np
import pandas as pd
import statsmodels.api as sm
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.model_selection import cross_val_score

# generate explanatory variables
x1 = np.random.normal(40, 4, 1000)
x2 = random.choices(["Male", "Female"], k=1000)
error = np.random.normal(0, 1, 1000)
y = 1234 + (4*x1) + error

# collect data in a dataframe
df = pd.DataFrame(zip(y, x1, x2), columns=['wage', 'workhours', 'gender'])

# treat gender as categorical and encode it as a single 0/1 dummy;
# dtype=float because pandas >= 2.0 returns bool dummies, which would turn X
# into an object array that statsmodels refuses to fit
df["gender"] = pd.Categorical(df["gender"])
df["gender"] = pd.get_dummies(df["gender"], drop_first=True, dtype=float).iloc[:, 0]
X = df[["workhours", "gender"]]

class SMWrapper(RegressorMixin, BaseEstimator):
    """ A universal sklearn-style wrapper for statsmodels regressors """
    def __init__(self, model_class, fit_intercept=True):
        self.model_class = model_class
        self.fit_intercept = fit_intercept

    def fit(self, X, y):
        if self.fit_intercept:
            X = sm.add_constant(X, has_constant="add")
        self.model_ = self.model_class(y, X)
        self.results_ = self.model_.fit()
        return self

    def predict(self, X):
        if self.fit_intercept:
            X = sm.add_constant(X, has_constant="add")
        return self.results_.predict(X)

print(cross_val_score(SMWrapper(sm.GLM), X, y, scoring='r2'))
```