如何使用 statsmodels 中常用的模型包装器来应用交叉验证?

问题描述 投票:0回答:0

我在这里阅读了相关讨论:Using statsmodel estimations with scikit-learn cross validation, is it possible?

在链接的讨论中,建议对

statsmodels
中的模型使用包装器,以便可以使用
cross_val_score
库中的
sklearn
函数。代码确实运行了,但我不确定要提供什么参数。

示例代码:

class SMWrapper(BaseEstimator, RegressorMixin):
    """ A universal sklearn-style wrapper for statsmodels regressors """
    def __init__(self, model_class, fit_intercept=True):
        self.model_class = model_class
        self.fit_intercept = fit_intercept
    def fit(self, X, y):
        if self.fit_intercept:
            X = sm.add_constant(X)
        self.model_ = self.model_class(y, X)
        self.results_ = self.model_.fit()
        return self
    def predict(self, X):
        if self.fit_intercept:
            X = sm.add_constant(X)
        return self.results_.predict(X)


cross_val_score(SMWrapper(sm.GLM), X, y, scoring='r2')

我有几个关于

cross_val_score
statsmodels
模型一起使用的问题:

问题:

  • 我可以使用基于公式的版本吗?也就是说,我想直接导入glm(
    from statsmodels.formula.api import glm
    )而不是使用
    sm.GLM
    。那么我如何提供数据/变量/设计矩阵?
  • 提供什么
    X
    y
    论据?在基于公式的方法中,我可以只提供一个完整的数据框
    df
    并使用 patsy 语法指定相关变量。
  • 如何论证
    sm.GLM
    ?我需要指定分发和链接功能,但我认为 SMWrapper 只接受一般模型。
    kwargs=dict(family=sm.families.Gaussian())
    作为 SMWrapper 的论据有效吗?

完整代码示例:

import pandas as pd
import numpy as np
import statsmodels.api as sm
import random
from sklearn.model_selection import cross_val_score
from sklearn.base import BaseEstimator, RegressorMixin

# generate explanatory variables
x1 = np.random.normal(40, 4, 1000)
x2 = random.choices(["Male", "Female"], k=1000)
error = np.random.normal(0, 1, 1000)
y = 1234 + (4*x1) + error

# collect data in a dataframe
df = pd.DataFrame(zip(y, x1, x2), columns=['wage', 'workhours', 'gender'])

# treat gender as categorical
df.gender = pd.Categorical(df.gender)
df.gender = pd.get_dummies(df.gender, drop_first=True)
X = df[["workhours", "gender"]]    

class SMWrapper(BaseEstimator, RegressorMixin):
    """ A universal sklearn-style wrapper for statsmodels regressors """
    def __init__(self, model_class, fit_intercept=True):
        self.model_class = model_class
        self.fit_intercept = fit_intercept
    def fit(self, X, y):
        if self.fit_intercept:
            X = sm.add_constant(X)
        self.model_ = self.model_class(y, X)
        self.results_ = self.model_.fit()
        return self
    def predict(self, X):
        if self.fit_intercept:
            X = sm.add_constant(X)
        return self.results_.predict(X)

cross_val_score(SMWrapper(sm.GLM), X, y, scoring='r2')
python scikit-learn statsmodels cross-validation glm
© www.soinside.com 2019 - 2024. All rights reserved.