[使用字符串数据框在Python中绘制混淆矩阵

问题描述 投票:0回答:1

我正在进行10倍验证,我需要查看每个类的准确性如何变化。我设法创建了一个这样的DataFrame:

Snippet:

chars = []
for i in range(0, int(classes) + 1):
    row = []
    for j in range(0, int(classes) + 1):
        row.append(str(round(means[i, j], 3)) + " +/- " + str(round(stds[i, j], 3)))
    chars.append(row)

con_mat_df = pd.DataFrame(chars, index=classes_list, columns=classes_list)
           0                1   ...               14               15
0    100.0 +/- 0.0      0.0 +/- 0.0  ...      0.0 +/- 0.0      0.0 +/- 0.0
1   0.49 +/- 0.703  98.53 +/- 1.416  ...      0.0 +/- 0.0      0.0 +/- 0.0
2      0.0 +/- 0.0    0.12 +/- 0.36  ...      0.0 +/- 0.0      0.0 +/- 0.0
3      0.0 +/- 0.0      0.0 +/- 0.0  ...      0.0 +/- 0.0      0.0 +/- 0.0
4      0.0 +/- 0.0      0.0 +/- 0.0  ...      0.0 +/- 0.0      0.0 +/- 0.0
5   0.55 +/- 0.905    0.14 +/- 0.42  ...      0.0 +/- 0.0      0.0 +/- 0.0
6      0.0 +/- 0.0      0.0 +/- 0.0  ...      0.0 +/- 0.0      0.0 +/- 0.0
7      0.0 +/- 0.0      0.0 +/- 0.0  ...      0.0 +/- 0.0      0.0 +/- 0.0
8      0.0 +/- 0.0      0.0 +/- 0.0  ...      0.0 +/- 0.0      0.0 +/- 0.0
9   0.62 +/- 1.318      0.2 +/- 0.6  ...      0.0 +/- 0.0      0.0 +/- 0.0
10  0.65 +/- 0.927   0.24 +/- 0.265  ...      0.0 +/- 0.0      0.0 +/- 0.0
11  1.02 +/- 1.558      0.0 +/- 0.0  ...      0.0 +/- 0.0   1.36 +/- 1.482
12     0.0 +/- 0.0      0.0 +/- 0.0  ...      0.0 +/- 0.0      0.0 +/- 0.0
13   0.32 +/- 0.96      0.0 +/- 0.0  ...      0.0 +/- 0.0      0.0 +/- 0.0
14  0.78 +/- 1.191      0.0 +/- 0.0  ...  98.96 +/- 1.274      0.0 +/- 0.0
15     0.0 +/- 0.0      0.0 +/- 0.0  ...      0.0 +/- 0.0  94.78 +/- 6.884
[16 rows x 16 columns]

现在,我只希望能够按照以下示例进行绘制。我想知道怎么做。如果我使用sns.heatmap,它将引发错误(TypeError: ufunc 'isnan' not supported for the input types...)。有任何想法吗?谢谢。

enter image description here

python pandas matplotlib heatmap confusion-matrix
1个回答
0
投票

所以我发现的最简单的方法是(cm是均值的数组,cms是标准差的数组:]

def plot_confusion_matrix(cm, cms,  classes,
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    thresh = cm.max() / 2.

    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, '{0:.2f}'.format(cm[i, j]) + '\n$\pm$' + '{0:.2f}'.format(cms[i, j]),
                     horizontalalignment="center",
                     verticalalignment="center", fontsize=7,
                     color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')


# Plot non-normalized confusion matrix
plt.figure()
plot_confusion_matrix(means, stds, classes=classes_list)

enter image description here

© www.soinside.com 2019 - 2024. All rights reserved.