lgb.cv 和 cross_val_score 之间的差异导致使用 LightGBM 进行多类分类

问题描述 投票:0回答:1

我预计使用

lgb.cv
cross_val_score
时会得到类似的交叉验证结果,但它们差异很大:

import lightgbm as lgb
import pandas as pd
from sklearn import datasets
from sklearn.metrics import log_loss
from sklearn.model_selection import cross_val_score

from typing import Any, Dict, List


def log_loss_scorer(clf, X, y):
    y_pred = clf.predict_proba(X)
    return log_loss(y, y_pred)


iris = datasets.load_iris()
features = pd.DataFrame(columns=["f1", "f2", "f3", "f4"], data=iris.data)
target = pd.Series(iris.target, name="target")
# 1) Native API
dataset = lgb.Dataset(features, target, feature_name=list(features.columns), free_raw_data=False)

native_params: Dict[str, Any] = {
    "objective": "multiclass", "boosting_type": "gbdt", "learning_rate": 0.05, "num_class": 3, "seed": 41
}
cv_logloss_native: float = lgb.cv(
    native_params, dataset, num_boost_round=1000, nfold=5, metrics="multi_logloss", seed=41, stratified=False,
    shuffle=False
)['valid multi_logloss-mean'][-1]

# 2) ScikitLearn API
model_scikit = lgb.LGBMClassifier(
    objective="multiclass", boosting_type="gbdt", learning_rate=0.05, n_estimators=1000, random_state=41
)
cv_logloss_scikit_list: List[float] = cross_val_score(
    model_scikit, features, target, scoring=log_loss_scorer
)
cv_logloss_scikit: float = sum(cv_logloss_scikit_list) / len(cv_logloss_scikit_list)
print(f"Native logloss CV {cv_logloss_native}; Scikit logloss CV train {cv_logloss_scikit}")

我使用本机 API 获得了

0.8803800291063604
的分数,使用
0.37528027519836027
API 获得了
scikit-learn
的分数。我尝试了不同的指标,但两种方法之间仍然得到非常不同的结果。 这种差异是否有具体原因?如何调整两种方法之间的结果?

python scikit-learn cross-validation multiclass-classification lightgbm
1个回答
0
投票

我看到了多种潜在的差异来源:

您的原生 LGBM API 代码集

stratified=False
。这可能会导致折叠不平衡。
scikit-learn
cross_val_score
自动对分类任务的折叠进行分层,确保每个类别的均衡表示。

Shuffle:您已在原生 API 中设置了

shuffle=False
,以保持数据顺序。相反,除非另有说明,
cross_val_score
将在折叠之前对数据进行打乱。

自定义评分器:您已使用

scikit-learn
的自定义评分器来计算多类对数损失。尽管您已将本机 API 指标设置为 multi_logloss,但由于实现不同,计算可能会存在细微差异。

© www.soinside.com 2019 - 2024. All rights reserved.