如何获取 pandas.DataFrame.plot 创建的绘图的内部创建的颜色条实例?
以下是生成彩色散点图的示例:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import itertools as it
# [ (0,0), (0,1), ..., (9,9) ]
xy_positions = list( it.product( range(10), range(10) ) )
df = pd.DataFrame( xy_positions, columns=['x','y'] )
# draw 100 floats
df['score'] = np.random.random( 100 )
ax = df.plot( kind='scatter',
x='x',
y='y',
c='score',
s=500)
ax.set_xlim( [-0.5,9.5] )
ax.set_ylim( [-0.5,9.5] )
plt.show()
如何获取颜色条实例以对其进行操作,例如更改标签或设置刻度?
pandas
不会返回颜色条的轴,因此我们必须找到它:
首先,让我们获取
figure
实例:即使用 plt.gcf()
In [61]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import itertools as it
# [ (0,0), (0,1), ..., (9,9) ]
xy_positions = list( it.product( range(10), range(10) ) )
df = pd.DataFrame( xy_positions, columns=['x','y'] )
# draw 100 floats
df['score'] = np.random.random( 100 )
ax = df.plot( kind='scatter',
x='x',
y='y',
c='score',
s=500)
ax.set_xlim( [-0.5,9.5] )
ax.set_ylim( [-0.5,9.5] )
f = plt.gcf()
2、这个图形有多少个轴?
In [62]:
f.get_axes()
Out[62]:
[<matplotlib.axes._subplots.AxesSubplot at 0x120a4d450>,
<matplotlib.axes._subplots.AxesSubplot at 0x120ad0050>]
3,第一个轴(即创建的第一个轴)包含绘图
In [63]:
ax
Out[63]:
<matplotlib.axes._subplots.AxesSubplot at 0x120a4d450>
4, 因此,第二个轴是颜色条轴
In [64]:
cax = f.get_axes()[1]
#and we can modify it, i.e.:
cax.set_ylabel('test')
不太一样,但你可以使用 matplotlib 进行绘图:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import itertools as it
# [ (0,0), (0,1), ..., (9,9) ]
xy_positions = list( it.product( range(10), range(10) ) )
df = pd.DataFrame( xy_positions, columns=['x','y'] )
# draw 100 floats
df['score'] = np.random.random( 100 )
fig = plt.figure()
ax = fig.add_subplot(111)
s = ax.scatter(df.x, df.y, c=df.score, s=500)
cb = plt.colorbar(s)
cb.set_label('desired_label')
ax.set_xlim( [-0.5,9.5] )
ax.set_ylim( [-0.5,9.5] )
plt.show()
回答原来的问题:可以设置
colorbar=False
并单独生成。它需要一个“可映射的”,即包含颜色图信息的 matplotlib 对象。这里是存储在ax.collections[0]
中的散点。通过返回值 cbar = plt.colorbar(...)
,您可以访问颜色条及其 ax
。
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
import pandas as pd
import numpy as np
# [ (0,0), (0,1), ..., (9,9) ]
df = pd.DataFrame({'x': np.repeat(range(10), 10), 'y': np.tile(range(10), 10)})
# draw 100 floats
df['score'] = np.random.random(100)
ax = df.plot(kind='scatter',
x='x',
y='y',
c='score',
s=500,
cmap='RdYlGn',
vmin=0,
vmax=1,
colorbar=False)
ax.set_xlim([-0.5, 9.5])
ax.set_ylim([-0.5, 9.5])
ax.xaxis.set_major_locator(MultipleLocator(1))
ax.yaxis.set_major_locator(MultipleLocator(1))
cbar = plt.colorbar(ax.collections[0], ax=ax)
cbar.set_ticks([0, 0.5, 1])
cbar.set_ticklabels(['low', 'medium', 'high'])
cbar.ax.set_title('Score', ha='left')
plt.tight_layout()
plt.show()