我有以下数据用于训练模型来检测句子是否是关于:
我运行了以下代码来训练 DecisionTreeClassifier() 模型,然后查看树可视化:
import numpy as np
from numpy.random import seed
import random as rn
import os
import pandas as pd
seed_num = 1
os.environ['PYTHONHASHSEED'] = '0'
np.random.seed(seed_num)
rn.seed(seed_num)
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
dummy_train = pd.read_csv('dummy_train.csv')
tree_clf = tree.DecisionTreeClassifier()
X_train = dummy_train["text"]
y_train = dummy_train["label"]
dt_tree_pipe = Pipeline([('vect', CountVectorizer(ngram_range=(1,1),
binary=True)),
('tfidf', TfidfTransformer(use_idf=False)),
('clf', DecisionTreeClassifier(random_state=seed_num,
class_weight={0:1, 1:1})),
])
tree_model_fold_1 = dt_tree_pipe.fit(X_train, y_train)
tree.plot_tree(dt_tree_pipe["clf"])
...产生以下树:
第一个节点检查 x[7] 是否小于或等于 0.177。 如何找出 x[7] 代表哪个单词?
我尝试了以下代码,但输出中返回的单词(“描述”和“该”)看起来不正确。我本以为“猫”和“狗”是用来将数据分为正类和负类的两个词。
vect_from_pipe = dt_tree_pipe["vect"]
words = vect_from_pipe.vocabulary_.keys()
print(list(words)[7])
print(list(words)[5])
vocabulary_
属性无序;事实上,该字典的值告诉您特征索引:
词汇_:字典
术语到特征索引的映射。
由于我们已经非常清楚树中的两个特征应该是什么,因此您只需检查
vect_from_pipe.vocabulary_['cat'], vect_from_pipe.vocabulary_['dog']
看看它们是否是 5 和 7。否则,您将需要反转字典,查找 5 的值7 看看对应的按键是什么。但使用 vect_from_pipe.get_feature_names_out()
并查看其中的第 5 个和第 7 个索引会更容易。事实上,在plot_tree
中使用它很常见:
tree.plot_tree(
dt_tree_pipe[-1],
feature_names = df_tree_pipe[:-1].get_feature_names_out(),
)
在
scikit-learn
中,您要查找的术语是功能名称。这些是应用转换之前的输入。
在代码中,您正在访问
vocabulary_
的 CountVectorizer
属性,它返回一个字典,其中键是单词,值是索引。当您将键转换为列表并访问第 7 个或第 5 个元素时,它不一定对应于特征矩阵中第 7 个或第 5 个索引处的单词。
要获取特定索引对应的特征名称(单词),应该使用
get_feature_names()
的CountVectorizer
方法。此方法返回按特征矩阵中相应索引排序的特征名称列表。
以下是修改代码的方法:
vect_from_pipe = dt_tree_pipe["vect"]
feature_names = vect_from_pipe.get_feature_names()
print(feature_names[7])
print(feature_names[5])
这将打印与特征矩阵中索引 7 和 5 相对应的单词。索引 7 处的单词是决策树第一次分割中使用的单词。因此,就您而言,决策树中的
x[7]
对应于您的 feature_names[7]
中的单词 CountVectorizer
。