以安全正确的方式使用RandomForestClassifier的predict_proba()函数。

问题描述 投票:20回答:2

我正在使用Scikit-learn在我的数据集上应用机器学习算法。有时我需要标签类的概率,而不是标签类本身。我不希望用SpamNot Spam作为邮件的标签,而是希望只用0.78的概率来表示某封邮件是Spam。

为了达到这个目的,我使用了 predict_proba() 与RandomForestClassifier如下。

clf = RandomForestClassifier(n_estimators=10, max_depth=None,
    min_samples_split=1, random_state=0)
scores = cross_val_score(clf, X, y)
print(scores.mean())

classifier = clf.fit(X,y)
predictions = classifier.predict_proba(Xtest)
print(predictions)

我得到了这些结果。

 [ 0.4  0.6]
 [ 0.1  0.9]
 [ 0.2  0.8]
 [ 0.7  0.3]
 [ 0.3  0.7]
 [ 0.3  0.7]
 [ 0.7  0.3]
 [ 0.4  0.6]

其中第二列是类: 垃圾邮件. 然而,我对结果有两个主要问题,我对这些结果没有信心。第一个问题是,结果代表了标签的概率,而没有受到我的数据大小的影响?第二个问题是结果只显示一个数字,在某些情况下,0.701的概率和0.708的概率相差很大,这不是很具体。比如说有什么办法可以得到下一个5位数吗?

python machine-learning scikit-learn random-forest
2个回答
6
投票
  1. 在我的结果中,我得到了超过一个数字,你确定这不是由于你的数据集吗? (例如,使用一个非常小的数据集将产生简单的决策树,所以 "简单 "的概率)。否则可能只是显示一个数字,但是试着打印出 predictions[0,0].

  2. 我不太明白你说的 "概率不受数据大小的影响 "是什么意思。如果您担心的是您不想预测,例如,太多的垃圾邮件,通常做的是使用一个阈值的 t 使你预测1,如果 proba(label==1) > t. 这样你就可以使用阈值来平衡你的预测,例如限制垃圾邮件的全局概率。如果您想在全球范围内分析您的模型,我们通常会计算接收者操作特征(ROC)曲线下的面积(AUC)(见维基百科文章《全球垃圾信息的概率》)。此处). 基本上,ROC曲线是对你的预测的描述,它取决于阈值 t.

希望对大家有所帮助!


21
投票

A RandomForestClassifier 是一个集合 DecisionTreeClassifier's. 无论你的训练集有多大,决策树只需返回:一个决策。一个类的概率为1,其他类的概率为0。

RandomForest只是在结果中进行投票。predict_proba() 返回每个类的投票数(森林中的每棵树都会做出自己的决定,并准确地选择一个类),除以森林中的树的数量。因此,你的精度正好是 1/n_estimators. 想要更 "精准"?添加更多的估算器。如果你想看到第5位数的变化,你将会需要 10**5 = 100,000 估算师,这是过度的。一般情况下,你不要超过100个估算师,往往也不会有那么多。

© www.soinside.com 2019 - 2024. All rights reserved.