Using GroupKFold in nested cross-validation with sklearn


My code is based on the example from the sklearn website: https://scikit-learn.org/stable/auto_examples/model_selection/plot_nested_cross_validation_iris.html

I am trying to use GroupKFold in both the inner and the outer CV.

from sklearn.datasets import load_iris
from matplotlib import pyplot as plt
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV, cross_val_score, KFold,GroupKFold
import numpy as np

# Load the dataset
iris = load_iris()
X_iris = iris.data
y_iris = iris.target

# Set up possible values of parameters to optimize over
p_grid = {"C": [1, 10, 100],
          "gamma": [.01, .1]}

# We will use a Support Vector Classifier with "rbf" kernel
svm = SVC(kernel="rbf")

# Choose cross-validation techniques for the inner and outer loops,
# independently of the dataset.
# E.g "GroupKFold", "LeaveOneOut", "LeaveOneGroupOut", etc.
inner_cv = GroupKFold(n_splits=3)
outer_cv = GroupKFold(n_splits=3)

# Non_nested parameter search and scoring
clf = GridSearchCV(estimator=svm, param_grid=p_grid, cv=inner_cv)

# Nested CV with parameter optimization
nested_score = cross_val_score(clf, X=X_iris, y=y_iris, cv=outer_cv, groups=y_iris)

I know that passing the y values as the groups argument is not what it is meant for! With this code, I get the following error.

.../anaconda3/lib/python3.7/site-packages/sklearn/model_selection/_validation.py:536: FitFailedWarning: Estimator fit failed. The score on this train-test partition for these parameters will be set to nan. Details: 
ValueError: The 'groups' parameter should not be None.

Do you have any ideas on how to solve this?

Thanks for your help,

Sören

python scikit-learn cross-validation
1 answer
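As you can see from the documentation of GroupKFold, you use it when you want K folds with non-overlapping groups. In other words, unless you have distinct groups of data that must be kept apart when the K folds are created, you should not use this method.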

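That said, for the given example you have to create groups manually. It should be an array-like object with the same shape as y, and the number of distinct groups must be at least equal to the number of folds.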
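The two requirements above can be checked with a short sketch. The iris data has no natural groups, so the round-robin assignment and the value of n_groups below are made up purely for illustration; with real data the groups would come from something like a subject or batch ID.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import GroupKFold

X, y = load_iris(return_X_y=True)

# Hypothetical grouping (an assumption, not part of the iris data):
# assign the samples to 10 made-up groups in round-robin order.
n_groups = 10
groups = np.arange(len(y)) % n_groups

n_splits = 3
# groups must have the same shape as y ...
assert groups.shape == y.shape
# ... and contain at least as many distinct groups as there are folds.
assert len(np.unique(groups)) >= n_splits

cv = GroupKFold(n_splits=n_splits)
for train_idx, test_idx in cv.split(X, y, groups):
    # No group appears on both sides of a split.
    assert set(groups[train_idx]).isdisjoint(groups[test_idx])
```

If the number of distinct groups is smaller than n_splits, GroupKFold raises a ValueError when splitting.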

Here is the example code from the documentation:

import numpy as np
from sklearn.model_selection import GroupKFold

X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
y = np.array([1, 2, 3, 4])
groups = np.array([0, 0, 2, 2])
group_kfold = GroupKFold(n_splits=2)
group_kfold.get_n_splits(X, y, groups)

You can see that groups has the same shape as y, and that it contains two distinct groups, 0 and 2, which equals the number of folds.
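As for the error in the question itself: in the sklearn version shown in the traceback, cross_val_score uses groups only to build the outer splits and does not forward it to the fit call of the inner GridSearchCV, so the inner GroupKFold ends up with groups=None. One possible workaround is to write the outer loop by hand and pass the matching slice of groups to the inner search. This is a minimal sketch, not the only way to do it, and the groups array is again a hypothetical round-robin assignment rather than anything taken from the question.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, GroupKFold
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)

# Hypothetical groups (an assumption for illustration only).
groups = np.arange(len(y)) % 10

p_grid = {"C": [1, 10, 100], "gamma": [0.01, 0.1]}
inner_cv = GroupKFold(n_splits=3)
outer_cv = GroupKFold(n_splits=3)

outer_scores = []
for train_idx, test_idx in outer_cv.split(X, y, groups):
    # Inner loop: parameter search on the outer training part only.
    clf = GridSearchCV(estimator=SVC(kernel="rbf"), param_grid=p_grid, cv=inner_cv)
    # Pass the groups of the training samples so the inner GroupKFold can split.
    clf.fit(X[train_idx], y[train_idx], groups=groups[train_idx])
    # Outer loop: score the refitted best model on the held-out groups.
    outer_scores.append(clf.score(X[test_idx], y[test_idx]))

nested_score = np.array(outer_scores)
print(nested_score.mean())
```

Each outer training part must still contain at least as many distinct groups as inner_cv has folds, otherwise the inner split fails.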