Matplotlib 图例用颜色映射?

问题描述 投票:0回答:2

总的来说,我对 matplotlib 真的很困惑。我通常只使用 import matplotlib.pyplot as plt。

然后执行 plt.figure()、plt.scatter()、plt.xlabel()、plt.show() 等操作。 但后来我用谷歌搜索如何做类似的事情,用颜色映射图例,我得到了所有这些包括 axe 的例子。但是有 plt.legend() 并且 matplotlib 文档中的示例只显示了 plt.legend(handles) 但没有显示句柄应该是什么。如果我想做 ax 的事情,那么我必须重写所有代码,因为我想使用 plt,因为它更简单。

这是我的代码:

import matplotlib.pyplot as plt

colmap = {
    "domestic": "blue",
    "cheetah": "red",
    "leopard": "green",
    "tiger": "black"
}

colours = []

for i in y_train:
    colours.append(colmap[i])
    
plt.figure(figsize= [15,5])
plt.scatter(X_train[:,0], X_train[:,2],c=colours)
plt.xlabel('weight')
plt.ylabel('height')
plt.grid()
plt.show()

现在我想添加一个图例,仅显示与我的字典中相同的颜色。但如果我这样做:

plt.legend(["国产","猎豹","豹子","老虎"])

它只在图例中显示“国内”,并且颜色是红色,这实际上与我的颜色编码方式不匹配。有没有办法做到这一点,而无需用“斧头”的东西重写所有内容?如果不是,我该如何适应 axe?我只是写 ax = plt.scatter(....) 吗?

python matplotlib legend
2个回答
1
投票

未提供数据,但此代码可以帮助您了解如何在 matplotlib 中向散点图添加颜色:

将 matplotlib.pyplot 导入为 plt 将 numpy 导入为 np

# data for scatter plots
x = list(range(0,30))
y = [i**2 for i in x]

# data for mapping class to color
y_train = ['domestic','cheetah', 'cheetah', 'tiger', 'domestic',
           'leopard', 'tiger', 'domestic', 'cheetah', 'domestic',
           'leopard', 'leopard', 'domestic', 'domestic', 'domestic',
           'domestic', 'cheetah', 'tiger', 'cheetah', 'cheetah',
           'domestic', 'domestic', 'domestic', 'cheetah', 'leopard',
           'cheetah', 'domestic', 'cheetah', 'tiger', 'domestic']

# color mapper
colmap = {
    "domestic": "blue",
    "cheetah": "red",
    "leopard": "green",
    "tiger": "black"
}

# create color array
colors = [colmap[i] for i in y_train]

# plot scatter    
plt.figure(figsize=(15,5))
plt.scatter(x, y, c=colors)
plt.xlabel('weight')
plt.ylabel('height')
plt.grid()
plt.show()

输出:


0
投票

RoseGod 在他的回答中给出了一个关于如何处理当前问题的很好的例子。对于使用 plt 和 ax 绘制图形之间的一般区别:

plt
调用 pyplot 库来做一些事情。这通常涉及最后打开的数字。当您同时仅使用一个图形/绘图进行简单绘图时,这种方法非常有效。
ax
是一个
Axes
对象,它引用一个特定(子)图以及该(子)图内的所有元素。这使您可以完全控制与(子)图相关的所有内容,尤其是在子图中同时绘制多个内容时。

另请参阅:matplotlib Axes.plot() 与 pyplot.plot()

© www.soinside.com 2019 - 2024. All rights reserved.