我试图创建一个包含9个子图(3 x 3)的图。X轴和Y轴的数据来自于使用groupby的数据框架。这是我的代码。
fig, axs = plt.subplots(3,3)
for index,cause in enumerate(cause_list):
df[df['CAT']==cause].groupby('RYQ')['NO_CONSUMERS'].mean().axs[index].plot()
axs[index].set_title(cause)
plt.show()
然而,它没有产生预期的输出。事实上,它返回了错误。如果我删除 axs[index]
之前 plot()
并放进 plot()
作用 plot(ax=axs[index])
然后它的工作,并产生9个子图,但没有显示数据在其中(如图所示)。
谁能指导我,我是在哪里犯的错误?
你需要将 axs
否则就是一个二维数组。而且你可以在绘图函数中提供ax,参见 大熊猫图谱,所以用一个例子。
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
cause_list = np.arange(9)
df = pd.DataFrame({'CAT':np.random.choice(cause_list,100),
'RYQ':np.random.choice(['A','B','C'],100),
'NO_CONSUMERS':np.random.normal(0,1,100)})
fig, axs = plt.subplots(3,3,figsize=(8,6))
axs = axs.flatten()
for index,cause in enumerate(cause_list):
df[df['CAT']==cause].groupby('RYQ')['NO_CONSUMERS'].mean().plot(ax=axs[index])
axs[index].set_title(cause)
plt.tight_layout()