Is there a way to make a bar chart from only the off-diagonal elements of a confusion matrix?

Question (votes: 0, answers: 1)

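I am running a classification task with three depth classes of earthquakes (shallow, intermediate and deep) plus a fourth class for noise. I am trying to visualize how often each kind of mislabelling occurs by building a bar chart from the off-diagonal elements of the confusion matrix.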

Building the confusion matrix and removing the diagonal elements:

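import numpy as np
from sklearn.metrics import confusion_matrix
# normalize='true' divides each row (true class) by its total, so every row sums to 1
cm = confusion_matrix(df_val, pred_labels, labels=inp.labels_list, normalize='true')
cm = cm[~np.eye(cm.shape[0], dtype=bool)].reshape(cm.shape[0], -1)  # removes diagonal components
frequency = cm.flatten()  # flatten to a 1-D array in row-major order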
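
A quick way to check that the masking step keeps the row order is to run it on a small matrix with known entries. This is only a minimal sketch: the 3x3 array `cm_demo` and the names `off_diag_mask` and `off_diag` are made up for illustration and are not data from the question.

```python
import numpy as np

# Hypothetical 3x3 "confusion matrix" with easily recognisable entries.
cm_demo = np.arange(9).reshape(3, 3)

# Boolean mask that is True everywhere except on the diagonal.
off_diag_mask = ~np.eye(cm_demo.shape[0], dtype=bool)

# Boolean indexing returns the kept entries in row-major order, so
# reshaping to (n, n - 1) puts the errors of each true class on one row.
off_diag = cm_demo[off_diag_mask].reshape(cm_demo.shape[0], -1)
print(off_diag)
```

Row i of the result holds the off-diagonal entries of row i in column order, which is why the flattened vector lines up with label pairs listed true class by true class.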

Building the DataFrame:

import pandas as pd

# One row per (true class, predicted class) pair, in the same row-major order as `frequency`
mislabel_counts = pd.DataFrame({
    'depth categories': ('shallow', 'shallow', 'shallow', 'inter', 'inter', 'inter', 'deep', 'deep', 'deep', 'noise', 'noise', 'noise'),
    'mislabelled as:': ('inter', 'deep', 'noise', 'shallow', 'deep', 'noise', 'shallow', 'inter', 'noise', 'shallow', 'inter', 'deep'),
    'frequency': tuple(frequency),
})

The DataFrame ends up looking like this:

   depth categories mislabelled as:  frequency
0           shallow           inter   0.362903
1           shallow            deep   0.209677
2           shallow           noise   0.233871
3             inter         shallow   0.177778
4             inter            deep   0.229630
5             inter           noise   0.296296
6              deep         shallow   0.183486
7              deep           inter   0.348624
8              deep           noise   0.238532
9             noise         shallow   0.166667
10            noise           inter   0.318182
11            noise            deep   0.257576
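
The hand-typed label tuples can also be generated from the label list, which avoids ordering mistakes if the classes change. The following is a sketch under stated assumptions: `labels` stands in for `inp.labels_list` and is assumed to be in the order shown, and the matrix `cm` contains made-up placeholder numbers, not the results from the question.

```python
import numpy as np
import pandas as pd

# Assumed label order; in the question this would come from inp.labels_list.
labels = np.array(['shallow', 'inter', 'deep', 'noise'])

# Placeholder row-normalised confusion matrix (made-up numbers, rows sum to 1).
cm = np.array([
    [0.40, 0.30, 0.20, 0.10],
    [0.10, 0.50, 0.20, 0.20],
    [0.20, 0.30, 0.40, 0.10],
    [0.10, 0.30, 0.20, 0.40],
])

# Row and column indices of every off-diagonal cell, in row-major order.
true_idx, pred_idx = np.nonzero(~np.eye(len(labels), dtype=bool))

# Long-form table: one row per (true class, predicted class) pair.
mislabel_counts = pd.DataFrame({
    'depth categories': labels[true_idx],
    'mislabelled as:': labels[pred_idx],
    'frequency': cm[true_idx, pred_idx],
})
print(mislabel_counts)
```

Because the indices come from the same mask used to drop the diagonal, the label pairs and the frequencies cannot fall out of step with each other.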

Plotting:

import seaborn as sns

sns.set_theme(font_scale=2.5)
sns.set_palette('colorblind')  # sns.color_palette() only returns a palette; set_palette() applies it
sns.barplot(data=mislabel_counts, x='depth categories', y='frequency', hue='mislabelled as:', width=0.5)

The plot ends up looking like this:

My problem is that the plot reserves space for the diagonal components even though they have been removed. Does anyone know how to prevent this?

python plot scikit-learn seaborn bar-chart
1 Answer

0 votes

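The simplest solution is to create an sns.catplot with sharex=False, so that each facet gets its own x axis showing only the categories that occur in it: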

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

mislabel_counts = {
    'depth categories': ('shallow', 'shallow', 'shallow', 'inter', 'inter', 'inter', 'deep', 'deep', 'deep', 'noise', 'noise', 'noise'),
    'mislabelled as:': ('inter', 'deep', 'noise', 'shallow', 'deep', 'noise', 'shallow', 'inter', 'noise', 'shallow', 'inter', 'deep'),
    'frequency': [.362903, .209677, .233871, .177778, .229630, .296296, .183486, .348624, .238532, .166667, .318182, .257576]
}
mislabel_counts = pd.DataFrame(mislabel_counts)

sns.set()
sns.catplot(data=mislabel_counts, kind='bar', x='mislabelled as:', y='frequency', col='depth categories',
            hue='mislabelled as:', dodge=False, sharex=False)

plt.show()
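
If a single set of axes is preferred over one facet per class, an alternative (not part of the answer above) is to position the bars by hand with matplotlib. sns.barplot with hue reserves one slot per hue level in every x group, which is where the empty diagonal slot comes from; placing exactly three bars per group avoids it. This is a rough sketch using the DataFrame from the question; the bar width and the tab10 colour mapping are arbitrary choices.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.patches import Patch

labels = ['shallow', 'inter', 'deep', 'noise']
mislabel_counts = pd.DataFrame({
    'depth categories': np.repeat(labels, 3),
    'mislabelled as:': ['inter', 'deep', 'noise', 'shallow', 'deep', 'noise',
                        'shallow', 'inter', 'noise', 'shallow', 'inter', 'deep'],
    'frequency': [.362903, .209677, .233871, .177778, .229630, .296296,
                  .183486, .348624, .238532, .166667, .318182, .257576],
})

bar_width = 0.25  # arbitrary; three bars fill 0.75 of each unit-wide group
colors = dict(zip(labels, plt.cm.tab10.colors))  # one fixed colour per predicted class

fig, ax = plt.subplots()
for group_pos, true_label in enumerate(labels):
    group = mislabel_counts[mislabel_counts['depth categories'] == true_label]
    # Centre the bars of this group on its tick: no slot is kept for the diagonal.
    offsets = (np.arange(len(group)) - (len(group) - 1) / 2) * bar_width
    ax.bar(group_pos + offsets, group['frequency'], width=bar_width,
           color=[colors[pred] for pred in group['mislabelled as:']])

ax.set_xticks(range(len(labels)))
ax.set_xticklabels(labels)
ax.set_xlabel('depth categories')
ax.set_ylabel('frequency')

# Legend built from proxy patches so each predicted class appears once.
handles = [Patch(color=colors[label]) for label in labels]
ax.legend(handles, labels, title='mislabelled as:')
```

Calling plt.show() or fig.savefig(...) afterwards displays or stores the figure. The trade-off compared with catplot is that the same predicted class sits at a different position within each group, so the colours and the legend carry the meaning.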
