我一直在开发一个 Jupyter Notebook,它接收 CSV 文件,对其进行操作并生成各种模型和视觉元素来描述它们。
我使用的一个工具是来自 scikit-learn 的混淆矩阵,最初我使用
plot_confusion_matrix
函数,但是自从通过 pip
更新后,我注意到这个函数已经贬值并删除了。而是替换为 ConfusionMatrixDisplay
我发现很难直接切换函数而不产生错误,有谁知道如何为当前的
scikit-learn
函数重写它们吗?
def plot_confusMatrix(cm, classes,
title='Confusion matrix',
cmap=plt.cm.Blues):
plt.rcParams.update({'font.size': 19})
plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title,fontdict={'size':'16'})
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45,fontsize=12,color="blue")
plt.yticks(tick_marks, classes,fontsize=12,color="blue")
rc('font', weight='bold')
fmt = '.1f'
thresh = cm.max()
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
plt.text(j, i, format(cm[i, j], fmt),
horizontalalignment="center",
color="red")
plt.ylabel('True label',fontdict={'size':'16'})
plt.xlabel('Predicted label',fontdict={'size':'16'})
plt.tight_layout()
plot_confusMatrix(confusion_matrix(y_test, y_pred=y_pred), classes=['Non Fraud','Fraud'],
title='Confusion matrix')
和
plot_confusion_matrix(confusion_matrix(y_test, y_pred=y_pred), classes=['Non Fraud','Fraud'],
title='Confusion matrix')
它曾经产生过,但我无法使用新的混淆矩阵函数
尝试使用此代码
cm = confusion_matrix(y_test, predictions, labels=clf.classes_)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=clf.classes_) disp.plot()
plt.show()
或者看看这个主题如何绘制混淆矩阵?