如何提取由DecisionTreeClassifier()和plot_tree()创建的决策树中每个节点使用的单词?

问题描述 投票:0回答:2

我有以下数据用于训练模型来检测句子是否是关于:

  • 猫或狗
  • 与猫或狗无关

我运行了以下代码来训练 DecisionTreeClassifier() 模型,然后查看树可视化:

import numpy as np
from numpy.random import seed
import random as rn
import os
import pandas as pd
seed_num = 1
os.environ['PYTHONHASHSEED'] = '0'
np.random.seed(seed_num)
rn.seed(seed_num)

from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree

dummy_train = pd.read_csv('dummy_train.csv')

tree_clf = tree.DecisionTreeClassifier()

X_train = dummy_train["text"]
y_train = dummy_train["label"]

dt_tree_pipe = Pipeline([('vect', CountVectorizer(ngram_range=(1,1),
                                                 binary=True)),
                     ('tfidf', TfidfTransformer(use_idf=False)),
                      ('clf', DecisionTreeClassifier(random_state=seed_num,
                                                 class_weight={0:1, 1:1})),
                   ])

tree_model_fold_1 = dt_tree_pipe.fit(X_train, y_train)

tree.plot_tree(dt_tree_pipe["clf"])

...产生以下树:

第一个节点检查 x[7] 是否小于或等于 0.177。 如何找出 x[7] 代表哪个单词?

我尝试了以下代码,但输出中返回的单词(“描述”和“该”)看起来不正确。我本以为“猫”和“狗”是用来将数据分为正类和负类的两个词。

vect_from_pipe = dt_tree_pipe["vect"]
words = vect_from_pipe.vocabulary_.keys()
print(list(words)[7])
print(list(words)[5])

python scikit-learn nlp tree decision-tree
2个回答
0
投票

vocabulary_
属性无序;事实上,该字典的值告诉您特征索引:

词汇_:字典
术语到特征索引的映射。

由于我们已经非常清楚树中的两个特征应该是什么,因此您只需检查

vect_from_pipe.vocabulary_['cat'], vect_from_pipe.vocabulary_['dog']
看看它们是否是 5 和 7。否则,您将需要反转字典,查找 5 的值7 看看对应的按键是什么。但使用
vect_from_pipe.get_feature_names_out()
并查看其中的第 5 个和第 7 个索引会更容易。事实上,在
plot_tree
中使用它很常见:

tree.plot_tree(
    dt_tree_pipe[-1],
    feature_names = df_tree_pipe[:-1].get_feature_names_out(),
)

0
投票

scikit-learn
中,您要查找的术语是功能名称。这些是应用转换之前的输入。

在代码中,您正在访问

vocabulary_
CountVectorizer
属性,它返回一个字典,其中键是单词,值是索引。当您将键转换为列表并访问第 7 个或第 5 个元素时,它不一定对应于特征矩阵中第 7 个或第 5 个索引处的单词。

要获取特定索引对应的特征名称(单词),应该使用

get_feature_names()
CountVectorizer
方法。此方法返回按特征矩阵中相应索引排序的特征名称列表。

以下是修改代码的方法:

vect_from_pipe = dt_tree_pipe["vect"]
feature_names = vect_from_pipe.get_feature_names()
print(feature_names[7])
print(feature_names[5])

这将打印与特征矩阵中索引 7 和 5 相对应的单词。索引 7 处的单词是决策树第一次分割中使用的单词。因此,就您而言,决策树中的

x[7]
对应于您的
feature_names[7]
中的单词
CountVectorizer

© www.soinside.com 2019 - 2024. All rights reserved.