Scikit-Learn:标签在混淆矩阵中不匹配

问题描述 投票:1回答:1

假设我有一个数组(可能)有43个不同的值,例如

import pandas as pd
Y_test = pd.Series([4,4,4,42,42,0,1,1,19], dtype=int)
Y_hat = pd.Series([4,4,2,32,42,0,5,5,19], dtype=int)

[每当我尝试用以下方式绘制混淆矩阵时:

def create_conf_mat(index, y_test, y_hat):
    cm = confusion_matrix(y_test, y_hat)
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(cm)
    plt.title(f'Confusion Matrix ({index} features, 1 outcome)')
    fig.colorbar(cax)
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.savefig(f'confm_{index}.png')
    plt.savefig(f'confm_{index}.svg')
    plt.savefig(f'confm_{index}.pdf')
    return

我没有得到标签[0、1、2、4、5、19、32、42],但[0、1、2、3、4、5、6、7]。我试图通过使用y_test / y_hat中的唯一值作为标签参数来显式设置标签,但是它也不起作用。我什至试图将整数值转换为字符串,但是这样做,sklearn抱怨至少一个标签必须位于y_true中。有谁知道我如何才能将y_test和y_pred中的实际值绘制为混淆矩阵中的标签?

matplotlib scikit-learn confusion-matrix
1个回答
0
投票

documentation中所提示,关于labelsconfusion_matrix参数:

如果未指定任何值,则按顺序使用在y_true或y_pred中至少出现一次的那些。

因此,我们需要将两个列表放在一起,并提取唯一编号的列表:

labels = np.unique(np.concatenate([y_test.values, y_hat.values]))
plt.xticks(range(len(labels)), labels)
plt.yticks(range(len(labels)), labels)
© www.soinside.com 2019 - 2024. All rights reserved.