通过 K-Fold 交叉验证标准化数据

问题描述 投票:0回答:2

我正在使用 StratifiedKFold 所以我的代码看起来像这样

def train_model(X,y,X_test,folds,model):
    scores=[]
    for fold_n, (train_index, valid_index) in enumerate(folds.split(X, y)):
        X_train,X_valid = X[train_index],X[valid_index]
        y_train,y_valid = y[train_index],y[valid_index]        
        model.fit(X_train,y_train)
        y_pred_valid = model.predict(X_valid).reshape(-1,)
        scores.append(roc_auc_score(y_valid, y_pred_valid))
    print('CV mean score: {0:.4f}, std: {1:.4f}.'.format(np.mean(scores), np.std(scores)))
folds = StratifiedKFold(10,shuffle=True,random_state=0)
lr = LogisticRegression(class_weight='balanced',penalty='l1',C=0.1,solver='liblinear')
train_model(X_train,y_train,X_test,repeted_folds,lr)

现在在训练模型之前我想标准化数据,那么哪个是正确的方法?
1)

scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

在调用 train_model 函数之前执行此操作

2)
像这样在函数内部进行标准化

def train_model(X,y,X_test,folds,model):
    scores=[]
    for fold_n, (train_index, valid_index) in enumerate(folds.split(X, y)):
        X_train,X_valid = X[train_index],X[valid_index]
        y_train,y_valid = y[train_index],y[valid_index]
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_vaid = scaler.transform(X_valid)
        X_test = scaler.transform(X_test)
        model.fit(X_train,y_train)
        y_pred_valid = model.predict(X_valid).reshape(-1,)

        scores.append(roc_auc_score(y_valid, y_pred_valid))

    print('CV mean score: {0:.4f}, std: {1:.4f}.'.format(np.mean(scores), np.std(scores)))

根据我在第二个选项中的知识,我不会泄漏数据。那么如果我不使用管道,哪种方法是正确的,如果我想使用交叉验证,又如何使用管道?

python machine-learning pipeline cross-validation
2个回答
3
投票

确实,第二个选项更好,因为缩放器看不到

X_valid
的值来缩放
X_train

现在如果您要使用管道,您可以这样做:

from sklearn.pipeline import make_pipeline

def train_model(X,y,X_test,folds,model):
    pipeline = make_pipeline(StandardScaler(), model)
    ...

然后使用

pipeline
代替
model
。每次
fit
predict
调用时,它都会自动标准化手头的数据。

请注意,您还可以使用 scikit-learn 中的 cross_val_score 函数,并使用参数

scoring='roc_auc'


0
投票

IMO,如果您的数据很大,那么它可能并不重要(如果您使用的是 k 倍,情况可能并非如此),但既然可以,最好在交叉验证中进行(k 倍) ),或选项 2。

此外,请参阅this,了解有关交叉验证中过度拟合的更多信息。

© www.soinside.com 2019 - 2024. All rights reserved.