我正在进行10倍验证,我需要查看每个类的准确性如何变化。我设法创建了一个这样的DataFrame:
Snippet:
chars = []
for i in range(0, int(classes) + 1):
row = []
for j in range(0, int(classes) + 1):
row.append(str(round(means[i, j], 3)) + " +/- " + str(round(stds[i, j], 3)))
chars.append(row)
con_mat_df = pd.DataFrame(chars, index=classes_list, columns=classes_list)
0 1 ... 14 15
0 100.0 +/- 0.0 0.0 +/- 0.0 ... 0.0 +/- 0.0 0.0 +/- 0.0
1 0.49 +/- 0.703 98.53 +/- 1.416 ... 0.0 +/- 0.0 0.0 +/- 0.0
2 0.0 +/- 0.0 0.12 +/- 0.36 ... 0.0 +/- 0.0 0.0 +/- 0.0
3 0.0 +/- 0.0 0.0 +/- 0.0 ... 0.0 +/- 0.0 0.0 +/- 0.0
4 0.0 +/- 0.0 0.0 +/- 0.0 ... 0.0 +/- 0.0 0.0 +/- 0.0
5 0.55 +/- 0.905 0.14 +/- 0.42 ... 0.0 +/- 0.0 0.0 +/- 0.0
6 0.0 +/- 0.0 0.0 +/- 0.0 ... 0.0 +/- 0.0 0.0 +/- 0.0
7 0.0 +/- 0.0 0.0 +/- 0.0 ... 0.0 +/- 0.0 0.0 +/- 0.0
8 0.0 +/- 0.0 0.0 +/- 0.0 ... 0.0 +/- 0.0 0.0 +/- 0.0
9 0.62 +/- 1.318 0.2 +/- 0.6 ... 0.0 +/- 0.0 0.0 +/- 0.0
10 0.65 +/- 0.927 0.24 +/- 0.265 ... 0.0 +/- 0.0 0.0 +/- 0.0
11 1.02 +/- 1.558 0.0 +/- 0.0 ... 0.0 +/- 0.0 1.36 +/- 1.482
12 0.0 +/- 0.0 0.0 +/- 0.0 ... 0.0 +/- 0.0 0.0 +/- 0.0
13 0.32 +/- 0.96 0.0 +/- 0.0 ... 0.0 +/- 0.0 0.0 +/- 0.0
14 0.78 +/- 1.191 0.0 +/- 0.0 ... 98.96 +/- 1.274 0.0 +/- 0.0
15 0.0 +/- 0.0 0.0 +/- 0.0 ... 0.0 +/- 0.0 94.78 +/- 6.884
[16 rows x 16 columns]
现在,我只希望能够按照以下示例进行绘制。我想知道怎么做。如果我使用sns.heatmap
,它将引发错误(TypeError: ufunc 'isnan' not supported for the input types...
)。有任何想法吗?谢谢。
所以我发现的最简单的方法是(cm是均值的数组,cms是标准差的数组:]
def plot_confusion_matrix(cm, cms, classes,
cmap=plt.cm.Blues):
"""
This function prints and plots the confusion matrix.
Normalization can be applied by setting `normalize=True`.
"""
plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)
thresh = cm.max() / 2.
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
plt.text(j, i, '{0:.2f}'.format(cm[i, j]) + '\n$\pm$' + '{0:.2f}'.format(cms[i, j]),
horizontalalignment="center",
verticalalignment="center", fontsize=7,
color="white" if cm[i, j] > thresh else "black")
plt.tight_layout()
plt.ylabel('True label')
plt.xlabel('Predicted label')
# Plot non-normalized confusion matrix
plt.figure()
plot_confusion_matrix(means, stds, classes=classes_list)