我目前正在使用 scikit-learn 的 GridSearchCV 和 Pipeline 进行超参数调整。我的估计器是 GradientBoostingRegressor。在交叉验证过程中,我遇到了一种情况,其中一个模型的折叠产生了负分,而该模型的所有其他折叠似乎都是可以接受的。这种不一致使我怀疑该特定折叠的数据中存在潜在问题。
为了进一步调查,我想访问有关在具有负分数的特定折叠中使用了哪些数据点的信息。我怎样才能做到这一点?使用 GridSearchCV 时,有没有办法识别和检索与特定折叠相关的数据点?
这是一个最小的例子:
XGBoost_param_grid = {"learning_rate": [0.001, 0.01, 0.1, 0.5],
"n_estimators": [10, 100, 500, 1000],
"max_depth": [3, 10, None],
"max_features": ["sqrt", "log2", None]}
gs_obj = GridSearchCV(estimator=GradientBoostingRegressor(), cv=5)
pipeline = Pipeline([("scaler", MinMaxScaler()), ("model", gs_obj)])
pipeline.fit(X, y)
假设最好的模型每折叠具有以下分数 (r2):
折叠0:0.5 折叠1:0.4 折叠2:0.7 折叠3:0.2 折叠 4:- 0.7
现在我想知道这些折叠中(尤其是第 4 折叠)中有哪些样本。
GridSearchCV
属性 cv_results_
返回 numpy(屏蔽)ndarray 的字典。 文档dict
您可以通过传递正确的密钥找到数据点。{
'param_kernel': masked_array(data = ['poly', 'poly', 'rbf', 'rbf'],
mask = [False False False False]...)
'param_gamma': masked_array(data = [-- -- 0.1 0.2],
mask = [ True True False False]...),
'param_degree': masked_array(data = [2.0 3.0 -- --],
mask = [False False True True]...),
'split0_test_score' : [0.80, 0.70, 0.80, 0.93],
'split1_test_score' : [0.82, 0.50, 0.70, 0.78],
'mean_test_score' : [0.81, 0.60, 0.75, 0.85],
'std_test_score' : [0.01, 0.10, 0.05, 0.08],
'rank_test_score' : [2, 4, 3, 1],
'split0_train_score' : [0.80, 0.92, 0.70, 0.93],
'split1_train_score' : [0.82, 0.55, 0.70, 0.87],
'mean_train_score' : [0.81, 0.74, 0.70, 0.90],
'std_train_score' : [0.01, 0.19, 0.00, 0.03],
'mean_fit_time' : [0.73, 0.63, 0.43, 0.49],
'std_fit_time' : [0.01, 0.02, 0.01, 0.01],
'mean_score_time' : [0.01, 0.06, 0.04, 0.04],
'std_score_time' : [0.00, 0.00, 0.00, 0.01],
'params' : [{'kernel': 'poly', 'degree': 2}, ...],
}
split{}
键对应于一个特定的折叠,其中编号。大括号内代表折叠号。
代码:import numpy as np
# Assuming you have already performed the GridSearchCV and stored the results in 'gs_obj'
# Get the index of the fold
fold_index = 4
# Get the indices of the samples used in that fold
fold_samples_indices = gs_obj.cv.split(X).__next__()[fold_index]
# Get the actual data points from the original dataset
fold_samples = X[fold_samples_indices]
# Print the samples
print(fold_samples)