不知道get_n_splits的目的是什么,为什么要这样做。

问题描述 投票:0回答:1

我在跟踪一个 内核在Kaggle上 并看到了这段代码。

#Validation function
n_folds = 5

def rmsle_cv(model):
    kf = KFold(n_folds, shuffle=True, random_state=42).get_n_splits(train.values)
    rmse= np.sqrt(-cross_val_score(model, train.values, y_train, scoring="neg_mean_squared_error", cv = kf))
    return(rmse)

我明白了KFold的目的和用途,也明白了在'cross_val_score'中使用的事实,但我不明白的是为什么要使用'get_n_split'?我不明白的是为什么要使用'get_n_split'?据我所知,它返回的是用于交叉验证的迭代次数,也就是说,在这种情况下,返回的值是5。当然,对于这一行。

rmse= np.sqrt(-cross_val_score(model, train.values, y_train, scoring="neg_mean_squared_error", cv = kf))

cv = 5? 这对我来说没有任何意义。如果get_n_splits返回一个整数,为什么还要使用它呢?我想 KFold返回一个类get_n_splits 返回一个整数。

谁能帮我理清一下思路?

python scikit-learn cross-validation k-fold
1个回答
2
投票

我以为KFold会返回一个类,而 get_n_splits 返回一个整数。

当然可以 KFold 是一个类,其中一个类方法是 get_n_splits,它返回一个整数;你的所示的 kf 可变的

kf = KFold(n_folds, shuffle=True, random_state=42).get_n_splits(train.values)

不是 KFold 类对象的结果,它是 KFold().get_n_splits() 办法,而且它确实是一个整数。事实上,如果你检查一下 文件, get_n_splits() 甚至不需要任何参数(它们实际上是被忽略的,只是为了与其他类和方法的兼容性而存在)。

至于被质疑的 get_n_splits 方法,能够查询这样的类对象以获取它们的参数设置,这绝不是一个坏主意(相反);想象一下这样的情况:你有多个不同的 KFold 对象,你需要在程序流程中以编程的方式获得它们各自的CV折数。

© www.soinside.com 2019 - 2024. All rights reserved.