使用 scikit-learn 时如何找到我的树分裂的属性?

问题描述 投票:0回答:3

我一直在探索 scikit-learn,用熵和基尼分裂标准制作决策树,并探索差异。

我的问题是,我如何“打开引擎盖”并准确找出树在每个级别上分裂的属性以及它们相关的信息值,以便我可以看到两个标准在哪里做出不同的选择?

到目前为止,我已经探索了文档中概述的 9 种方法。他们似乎不允许访问此信息。但这些信息肯定是可以访问的吗?我正在设想一个包含节点和增益条目的列表或字典。

machine-learning scikit-learn decision-tree
3个回答
37
投票

直接来自文档(http://scikit-learn.org/0.12/modules/tree.html):

from io import StringIO
out = StringIO()
out = tree.export_graphviz(clf, out_file=out)
Python3 不再支持

StringIO
模块,而是导入
io
模块。

决策树对象中还有

tree_
属性,它允许直接访问整个结构。

您可以简单地阅读它

clf.tree_.children_left #array of left children
clf.tree_.children_right #array of right children
clf.tree_.feature #array of nodes splitting feature
clf.tree_.threshold #array of nodes splitting points
clf.tree_.value #array of nodes values

更多详情请查看导出方法源代码

一般来说你可以使用

inspect
模块

from inspect import getmembers
print( getmembers( clf.tree_ ) )

获取对象的所有元素

Decision tree visualization from sklearn docs


12
投票

如果您只是想快速查看树中发生的情况,请尝试:

zip(X.columns[clf.tree_.feature], clf.tree_.threshold, clf.tree_.children_left, clf.tree_.children_right)

其中X是自变量的数据框,clf是决策树对象。请注意,

clf.tree_.children_left
clf.tree_.children_right
一起包含分割的顺序(其中每一个都对应于 graphviz 可视化中的一个箭头)。


10
投票

Scikit learn 在 0.21 版本(2019 年 5 月)中引入了一种名为

export_text
的美味新方法,用于查看树中的所有规则。 文档在这里

一旦适合了模型,您只需要两行代码。首先,导入

export_text

from sklearn.tree import export_text

其次,创建一个包含您的规则的对象。为了使规则看起来更具可读性,请使用

feature_names
参数并传递功能名称列表。例如,如果您的模型名为
model
并且您的特征在名为
X_train
的数据框中命名,您可以创建一个名为
tree_rules
的对象:

tree_rules = export_text(model, feature_names=list(X_train))

然后只需打印或保存

tree_rules
。你的输出将如下所示:

|--- Age <= 0.63
|   |--- EstimatedSalary <= 0.61
|   |   |--- Age <= -0.16
|   |   |   |--- class: 0
|   |   |--- Age >  -0.16
|   |   |   |--- EstimatedSalary <= -0.06
|   |   |   |   |--- class: 0
|   |   |   |--- EstimatedSalary >  -0.06
|   |   |   |   |--- EstimatedSalary <= 0.40
|   |   |   |   |   |--- EstimatedSalary <= 0.03
|   |   |   |   |   |   |--- class: 1
© www.soinside.com 2019 - 2024. All rights reserved.