如何计算 K 折交叉验证的不平衡数据集的精度、召回率和 f1 分数?

问题描述 投票:0回答:2

我有一个包含二元分类问题的不平衡数据集。我已经构建了随机森林分类器并使用了 10 折的 k 折交叉验证。

kfold = model_selection.KFold(n_splits=10, random_state=42)
model=RandomForestClassifier(n_estimators=50) 

我得到了10折的结果

results = model_selection.cross_val_score(model,features,labels, cv=kfold)
print results
[ 0.60666667  0.60333333  0.52333333  0.73        0.75333333  0.72        0.7
  0.73        0.83666667  0.88666667]

我通过计算结果的均值和标准差来计算准确度

print("Accuracy: %.3f%% (%.3f%%)") % (results.mean()*100.0, results.std()*100.0)
Accuracy: 70.900% (10.345%)

我的预测如下

predictions = cross_val_predict(model, features,labels ,cv=10)

由于这是一个不平衡的数据集,我想计算每一次折叠的精度、召回率和f1分数,然后对结果进行平均。 如何计算 python 中的值?

python scikit-learn random-forest cross-validation supervised-learning
2个回答
39
投票

当你使用

cross_val_score
方法时,你可以指定,你可以计算每折的分数:

from sklearn.metrics import make_scorer, accuracy_score, precision_score, recall_score, f1_score

scoring = {'accuracy' : make_scorer(accuracy_score), 
           'precision' : make_scorer(precision_score),
           'recall' : make_scorer(recall_score), 
           'f1_score' : make_scorer(f1_score)}

kfold = model_selection.KFold(n_splits=10, random_state=42)
model=RandomForestClassifier(n_estimators=50) 

results = model_selection.cross_val_score(estimator=model,
                                          X=features,
                                          y=labels,
                                          cv=kfold,
                                          scoring=scoring)

交叉验证后,您将获得

results
字典,其键为:'accuracy'、'precision'、'recall'、'f1_score',它在每个折叠上存储特定指标的指标值。对于每个指标,您可以使用
np.mean(results[value])
np.std(results[value])
计算平均值和标准值,其中值 - 您指定的指标名称之一。


1
投票

您提到的所有分数——

accuracy
precision
recall
f1
——依赖于您(手动)为预测设置的阈值来预测类别。如果不指定阈值,默认阈值是 0.5,看这里。阈值应始终根据错误分类的成本来设置。如果没有给出成本,你应该做一个假设。

为了能够比较不同的模型或超参数,您可以考虑使用曲线下面积 (AUC) 作为 精度召回曲线,因为它通过显示不同阈值的精度和召回率而独立于阈值。在您的数据不平衡的特定情况下,PR-AUC 比 ROC 的 AUC 更合适,请参阅此处.

另见此处:https://datascience.stackexchange.com/a/96708/131238

© www.soinside.com 2019 - 2024. All rights reserved.