不确定get_n_splits的目的以及为什么有必要(如果有的话)

问题描述 投票:0回答:1

我正在跟踪kernel on Kaggle,并遇到了此代码。

#Validation function
n_folds = 5

def rmsle_cv(model):
    kf = KFold(n_folds, shuffle=True, random_state=42).get_n_splits(train.values)
    rmse= np.sqrt(-cross_val_score(model, train.values, y_train, scoring="neg_mean_squared_error", cv = kf))
    return(rmse)

我了解KFold的目的和用法,以及'cross_val_score'中使用的事实。我不明白的是为什么要使用'get_n_split'?据我所知,它返回用于交叉验证的迭代次数,即在这种情况下返回值5。当然这行:

rmse= np.sqrt(-cross_val_score(model, train.values, y_train, scoring="neg_mean_squared_error", cv = kf))

cv = 5?这对我来说毫无意义。如果它返回整数,为什么还要使用get_n_splits?我以为KFold returns a classget_n_splits返回一个整数。

任何人都可以清除我的理解吗?

python scikit-learn cross-validation k-fold
1个回答
0
投票

我以为KFold返回一个类,而get_n_splits返回一个整数。

当然,KFold是一个类,并且其中一个类方法是get_n_splits,它返回一个整数;您显示的kf变量

kf = KFold(n_folds, shuffle=True, random_state=42).get_n_splits(train.values)

不是KFold类对象,它是KFold().get_n_splits() 方法的结果,并且确实是整数。实际上,如果您检查documentation,则get_n_splits()甚至不需要任何参数(它们实际上被忽略,并且仅出于与其他类和方法的兼容性原因而存在)。

关于get_n_splits方法的实用程序,能够查询这样的类对象以取回其参数设置(相反)永远不是一个坏主意;想象一下您有多个不同的KFold对象,并且需要在程序流程中以编程方式获取它们各自的CV折叠数的情况。

© www.soinside.com 2019 - 2024. All rights reserved.