如何根据另一个数据帧列中的值对 matplotlib 直方图着色

问题描述 投票:0回答:1

我使用

创建一个数据框
import pandas as pd 
import matplotlib.pyplot as plt 

df_dict = {
    "test_predictions": [0.1, 0.1, 0.2, 0.2, 0.3, 0.3, 0.4, 0.4, 0.4, 0.4, 0.4, 0.5, 0.5, 0.6, 0.6, 0.6, 0.7, 0.7, 0.7, 0.7, 0.7, 0.8, 0.8, 0.8, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9],
    "y_true": [0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1],
    "distance" : [-0.1, -0.09, -0.08, -0.08, -0.07, -0.05, -0.05, -0.05, -0.05, -0.05, -0.04, -0.04, -0.04, -0.03, -0.02, -0.01, 0.01, 0.01, 0.01, 0.02, 0.03, 0.03, 0.04, 0.05, 0.05, 0.06, 0.06, 0.07, 0.08, 0.08, 0.09, 0.1]
}
df = pd.DataFrame(df_dict)

然后我使用

制作一个由两个直线图和一个直方图组成的图
fig, ax1 = plt.subplots()
ax1.plot([0, 1], [0, 1], color="red", linestyle=":", label="Perfect Model")
ax1.plot(df['test_predictions'], df['y_true'], label="NN3", color='blue')
ax2 = ax1.twinx()
ax2.hist(df['test_predictions'], bins=10, alpha=0.7, color='darkgreen', label='Histogram')

我想根据

df['distance']
中的值对直方图进行着色,并且还包括颜色图。因此,本质上直方图的一个 bin 中可能有多种颜色。任何帮助都感激不尽。谢谢!

编辑:

我在

ax2.hist(df['test_predictions'], bins=10, alpha=0.7, color='darkgreen', label='Histogram')

之前尝试过这样做
bins = np.linspace(df['test_predictions'].min(), df['test_predictions'].max(), 10)

for index, row in df.iterrows():
    bin_index = np.digitize(row['test_predictions'], bins)
    color = plt.cm.viridis(row['distance']/df['distance'].max())
    ax2.bar(bins[bin_index-1], 1, width=np.diff(bins)[0], color = color, alpha  = 0.7)

但是,我担心我使用 for 循环这一事实,因为我的实际数据帧可能包含超过 10000 行,而且当我这样做时我没有得到所需的输出,它像单独的条形图一样出现,而不是像那样堆叠起来直方图。

python pandas matplotlib data-analysis
1个回答
0
投票

您可以对属于每个条形的距离值使用

imshow

下面的代码首先创建一个额外的数据框列,其中包含每行的 bin id。然后,选择每个 bin 的距离,并将其用作

imshow()
的输入。

import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable
import pandas as pd
import numpy as np

df_dict = {
    "test_predictions": [0.1, 0.1, 0.2, 0.2, 0.3, 0.3, 0.4, 0.4, 0.4, 0.4, 0.4, 0.5, 0.5, 0.6, 0.6, 0.6, 0.7, 0.7, 0.7, 0.7, 0.7, 0.8, 0.8, 0.8, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9],
    "y_true": [0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1],
    "distance": [-0.1, -0.09, -0.08, -0.08, -0.07, -0.05, -0.05, -0.05, -0.05, -0.05, -0.04, -0.04, -0.04, -0.03, -0.02, -0.01, 0.01, 0.01, 0.01, 0.02, 0.03, 0.03, 0.04, 0.05, 0.05, 0.06, 0.06, 0.07, 0.08, 0.08, 0.09, 0.1]
}
df = pd.DataFrame(df_dict)

cmap = plt.get_cmap('RdYlBu')
norm = plt.Normalize(vmin=df['distance'].min(), vmax=df['distance'].max())

num_bins = 10
bins = np.linspace(df['test_predictions'].min(), df['test_predictions'].max() + 0.001, num_bins + 1)
df['bin'] = np.digitize(df['test_predictions'], bins)

fig, ax1 = plt.subplots()
for bin_id, bin_df in df.groupby('bin'):
    ax1.imshow(bin_df['distance'].values. Reshape(-1, 1), interpolation='nearest', cmap=cmap, norm=norm,
               extent=[bins[bin_id - 1], bins[bin_id], 0, len(bin_df)], aspect='auto')

ax1.use_sticky_edges = False # remove stickiness due to imshow
ax1.autoscale_view()
ax1.set_ylim(ymin=0)

plt.colorbar(ScalarMappable(norm=norm, cmap=cmap), label='Distance', ax=ax1)
plt.tight_layout()
plt.show()

© www.soinside.com 2019 - 2024. All rights reserved.