因此,给 scikit-learn 函数 roc_curve 两个 True 和 False 值数组
fpr, tpr, thresholds = roc_curve(self.real_values_discrete, self.predictions_discrete)
我收到这样的值:
>>> [0. 0.63888889 1. ]
>>> [0. 0.54330709 1. ]
>>> [2 1 0]
不过,使用公式和 scikit-learn 函数手动计算 FPR 和 TPR confusion_matrix:
confusion_matrix = confusion_matrix(self.real_values_discrete, self.predictions_discrete)
print(confusion_matrix)
_tp = confusion_matrix[0, 0]
_fn = confusion_matrix[0, 1]
_fp = confusion_matrix[1, 0]
_tn = confusion_matrix[1, 1]
_tpr = _tp / (_tp + _fn)
_fpr = _fp / (_tn + _fp)
print(_fpr)
print(_tpr)
我得到这两个值
>>> 0.4566929133858268
>>> 0.3611111111111111
我不明白为什么手工计算的值和上面数组的中间值不同。
这些值是否意味着不同,或者我不明白某些内容/在某处有错误?
roc_curve()
对分数进行操作(例如 predict_proba()
的结果),而不是预测。如果使用得当,它应该返回每个可能的分类阈值的 TPR 和 FPR 值(唯一得分计数 + 1 分)。
confusion_matrix()
对预测进行操作,因此假设默认阈值为 0.5。
我认为当混淆矩阵的值影响输出时,错误出现在混淆矩阵的索引中。通常,scikit-learn中的混淆矩阵具有以下结构:
[[tn, fp],
[fn, tp]]
鉴于此,应使用以下方法进行计算:
tn = confusion_matrix[0, 0]
fp = confusion_matrix[0, 1]
fn = confusion_matrix[1, 0]
tp = confusion_matrix[1, 1]
使用可重现示例进行解释
看这个简单的例子就能很好地理解这个机制:
我应用了相同的代码,但更正了从混淆矩阵的输出中分配值的部分中的索引:
from sklearn import metrics
real_values_discrete = [0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0]
predictions_discrete = [0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0]
fpr, tpr, thresholds = metrics.roc_curve(real_values_discrete, predictions_discrete)
print(fpr)
print(tpr)
print(thresholds)
confusion_matrix = metrics.confusion_matrix(real_values_discrete, predictions_discrete)
print(confusion_matrix)
tn = confusion_matrix[0, 0]
fp = confusion_matrix[0, 1]
fn = confusion_matrix[1, 0]
tp = confusion_matrix[1, 1]
tpr_manual = tp / (tp + fn)
fpr_manual = fp / (fp + tn)
print(fpr_manual)
print(tpr_manual)
结果如下:
[0. 0.42857143 1. ]
[0. 0.5 1. ]
[inf 1. 0.]
[[4 3]
[4 4]]
0.42857142857142855
0.5
我们清楚地看到,当应用我向您展示的修改时,您会得到相同的结果(使用这两个函数,frp 为 0.4285,tpr 为 0.5)。
但是当应用你的代码时,我们将得到结果:
[0. 0.42857143 1. ]
[0. 0.5 1. ]
[inf 1. 0.]
[[4 3]
[4 4]]
0.5
0.5714285714285714
(不幸的是,这两个函数得到了不同的结果)。
结论:
在 scikit-learn 中,混淆矩阵就像
[[tn, fp],
[fn, tp]]
因此,在给 tp、fp、fn 和 tn 赋值时要注意。
应用混淆矩阵索引的修改,我确信使用
roc_curve
函数或 confusion_matrix
函数会得到相同的结果。
等待您的反馈!祝你好运!