Python 中 SHAP 生成的特征重要性图的问题

问题描述 投票:0回答:1

我一直在使用 Python 中的 SHAP 库来可视化机器学习模型的特征重要性。这是我一直在使用的代码片段:

explainer = shap.TreeExplainer(trained_model_gbm)
shap_values = explainer.shap_values(x_test_selected)
shap_importance = np.abs(shap_values).mean(axis=0)
importance_df = pd.DataFrame({'features': selected_feature_names,
                              'importance': shap_importance})
importance_df.sort_values(by='importance', ascending=False, inplace=True)
print(importance_df)

shap_exp = shap.Explanation(values=shap_values, base_values=explainer.expected_value, data=x_test_selected,
                            feature_names=selected_feature_names)

shap.plots.beeswarm(shap_exp, max_display=len(selected_feature_names))

shap.plots.bar(shap_exp, max_display=len(selected_feature_names))

但是,我遇到了一个问题,生成的绘图的特征名称被裁剪,有时图像缩放效果不佳。我该如何解决这个问题?任何见解或建议将不胜感激!

尽管我尝试使用figsize参数调整图形大小并使用plt.show()修改默认图像设置,但我仍无法达到预期的结果。图中的特征名称仍在被裁剪,并且整体图像缩放仍然不一致。

python matplotlib visualization shap
1个回答
0
投票

我找到了解决 SHAP 图中特征名称被裁剪和图像缩放效果不佳的问题的解决方案。尽管该解决方案并不完美,因为它仍然截断了 y 轴上很长的特征名称,但它显着改善了绘图的整体外观和可读性。

解决方案如下:

# SHAP Explanation
explainer = shap.TreeExplainer(trained_best_model)
shap_values = explainer.shap_values(x_test_best)
shap_importance = np.abs(shap_values).mean(axis=0)
importance_df = pd.DataFrame({'features': best_features_overall,
                              'importance': shap_importance})
importance_df.sort_values(by='importance', ascending=False, inplace=True)

shap_exp = shap.Explanation(values=shap_values, base_values=explainer.expected_value, data=x_test_best,
                            feature_names=best_features_overall)

# SHAP Bee Swarm Plot
shap.plots.beeswarm(shap_exp, max_display=len(best_features_overall), show=False)
plt.ylim(-0.5,
         len(best_features_overall) - 0.5)  # Set Y-axis limits to avoid cutting off feature names
plt.subplots_adjust(left=0.5, right=0.9)  # Adjust left and right margins of the plot
plt.savefig(os.path.join(save_dir, 'shap_beeswarm.png'))
plt.close()

# SHAP Bar Plot
shap.plots.bar(shap_exp, max_display=len(best_features_overall), show=False)
plt.subplots_adjust(left=0.5, right=0.9)  # Adjust left and right margins of the plot
plt.savefig(os.path.join(save_dir, 'shap_bar.png'))
plt.close()

# Adding the line to save feature importance to a .txt file
with open('importance_of_features.txt', 'a') as file:
    file.write("\n\nSHAP Explanation - Feature Importance\n")
    file.write(importance_df.to_string(index=False))

此解决方案包括对 y 轴限制和子图边距的调整,以更好地处理蜂群图和条形图中特征名称的显示。此外,它将绘图保存为 PNG 文件,并将特征重要性写入文本文件以供进一步参考。

© www.soinside.com 2019 - 2024. All rights reserved.