我正在使用plot_confusion_matrix
中的sklearn.metrics
。我想像子图一样将那些混淆矩阵彼此相邻。
让我们使用good'ol虹膜数据集来重现此数据,并使用plot_confusion_matrix
拟合多个分类器以绘制各自的混淆矩阵:
plot_confusion_matrix
设置-
from sklearn.ensemble import AdaBoostClassifier, GradientBoostingClassifier
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import plot_confusion_matrix
from matplotlib import pyplot as plt
data = load_iris()
X = data.data
y = data.target
因此,您可以简单地比较所有矩阵,方法是X_train, X_test, y_train, y_test = train_test_split(X, y)
lg = LogisticRegression(solver = 'lbfgs')
ab = AdaBoostClassifier()
gb = GradientBoostingClassifier()
svc = SVC()
classifiers = [lg,ab,gb,svc]
for i, cls in enumerate(classifiers):
cls.fit(X_train, y_train)
,然后迭代返回的轴对象,并使用plt.subplots
绘制各个混淆矩阵:
plot_confusion_matrix
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(15,10))
axes = axes.flatten()
for i, ax in enumerate(axes.flatten()):
cls = classifiers[i]
plot_confusion_matrix(cls,
X_test,
y_test,
ax=ax,
cmap='Blues',
display_labels=data.target_names)
ax.title.set_text(type(cls).__name__)
plt.show()