使用plot_confusion_matrix绘制多个混淆矩阵

问题描述 投票:0回答:1

我正在使用plot_confusion_matrix中的sklearn.metrics。我想像子图一样将那些混淆矩阵彼此相邻。

python matplotlib machine-learning scikit-learn seaborn
1个回答
0
投票

让我们使用good'ol虹膜数据集来重现此数据,并使用plot_confusion_matrix拟合多个分类器以绘制各自的混淆矩阵:

plot_confusion_matrix

设置-

from sklearn.ensemble import AdaBoostClassifier, GradientBoostingClassifier
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import plot_confusion_matrix
from matplotlib import pyplot as plt

data = load_iris()
X = data.data
y = data.target

因此,您可以简单地比较所有矩阵,方法是X_train, X_test, y_train, y_test = train_test_split(X, y) lg = LogisticRegression(solver = 'lbfgs') ab = AdaBoostClassifier() gb = GradientBoostingClassifier() svc = SVC() classifiers = [lg,ab,gb,svc] for i, cls in enumerate(classifiers): cls.fit(X_train, y_train) ,然后迭代返回的轴对象,并使用plt.subplots绘制各个混淆矩阵:

plot_confusion_matrix

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(15,10)) axes = axes.flatten() for i, ax in enumerate(axes.flatten()): cls = classifiers[i] plot_confusion_matrix(cls, X_test, y_test, ax=ax, cmap='Blues', display_labels=data.target_names) ax.title.set_text(type(cls).__name__) plt.show()

© www.soinside.com 2019 - 2024. All rights reserved.