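I am trying to draw a line plot with matplotlib.pyplot and have a DataFrame df of shape (42, 7). The DataFrame has the following structure (only the relevant columns are shown):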
timepoint value point
2021-01-01 10 0
2021-02-01 20 0
....
2021-11-01 10 0
2021-12-01 50 1
2022-01-01 60 1
...
I tried to plot it with conditional colors as follows (every value with point = 0 should be blue, every value with point = 1 should be red):
import numpy as np
import matplotlib.pyplot as plt
col = np.where(df['point'] == 0, 'b', 'r')
plt.plot(df['timepoint'], df['value'], c=col)
plt.show()
I get the error message:
ValueError: array(['b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r', 'r'], dtype='
When I looked at the question "ValueError: Invalid RGBA argument: What is causing this error?", I could not find a solution, because my color array already has a one-dimensional shape:
col.shape
returns (42,)
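You can't use np.where to define the colors of a line plot (the way it is done for a scatter plot in the linked example), because in the former the c parameter expects a single color for the whole line, while in the latter you can pass an array of colors that is mapped to the individual points.
So, here is one possible way to work around this: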
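To make the difference concrete, here is a minimal sketch with made-up x and y values and a five-element color array (the data is only an illustration, not the original df). It assumes a reasonably recent matplotlib, where the line plot rejects the color array with a ValueError while the scatter plot accepts one color per marker.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is needed
import matplotlib.pyplot as plt
import numpy as np

# Illustrative data (assumption): five values, the last two flagged with point = 1
x = np.arange(5)
y = np.array([10, 20, 10, 50, 60])
col = np.where(np.array([0, 0, 0, 1, 1]) == 0, "b", "r")

fig, ax = plt.subplots()

# Line plot: a single line artist takes a single color, so the array is rejected
try:
    ax.plot(x, y, c=col)
    fig.canvas.draw()  # older matplotlib versions may only complain at draw time
    line_error = None
except ValueError as exc:
    line_error = exc

# Scatter plot: one color per marker is accepted
markers = ax.scatter(x, y, c=col)
```

After running this, line_error holds the exception raised for the line plot, and markers.get_facecolors() contains one RGBA row per point.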
(_, p0), (_, others) = df.groupby(df["point"].eq(0), sort=False)
plt.plot(p0.timepoint, p0.value, c="b", label="0 Points")
plt.plot(others.timepoint, others.value, c="r", label="Others")
plt.legend()
plt.show()
Alternatively, with pivot / plot:
other_points = df["point"].loc[lambda s: s.ne(0)].unique()
(
    df.pivot(index="timepoint", columns="point", values="value")
    .plot(style={0: "b", **{k: "r" for k in other_points}})
)
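To see why the pivot approach gives one line per point value, it can help to look at the intermediate wide table. The sketch below rebuilds it from a five-row sample (an assumption standing in for the real 42-row df); the names wide and style are introduced here only for illustration.

```python
import pandas as pd

# Sample data (assumption): a small stand-in for the real DataFrame
df = pd.DataFrame({
    "timepoint": pd.to_datetime(
        ["2021-01-01", "2021-02-01", "2021-11-01", "2021-12-01", "2022-01-01"]
    ),
    "value": [10, 20, 10, 50, 60],
    "point": [0, 0, 0, 1, 1],
})

# One column per distinct 'point' value; rows that belong to the other value become NaN
wide = df.pivot(index="timepoint", columns="point", values="value")

# Same style mapping as in the answer: blue for 0, red for every other point value
other_points = df["point"].loc[lambda s: s.ne(0)].unique()
style = {0: "b", **{k: "r" for k in other_points}}

print(wide)
print(style)
```

Because each column is drawn as its own line and NaN values are skipped, the blue and red lines are not joined where the point value changes.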
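The error occurs because the c parameter of plt.plot() expects a single valid color specification for the whole line, but you passed an array of color names ('b' for blue, 'r' for red), which is not a valid value for c.
To get conditional colors on a line plot, you can loop over the DataFrame rows and draw each segment in the desired color. Here is how to do that (using a sample DataFrame):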
import pandas as pd
import matplotlib.pyplot as plt

# Sample DataFrame
data = {
    'timepoint': ['2021-01-01', '2021-02-01', '2021-11-01', '2021-12-01', '2022-01-01'],
    'value': [10, 20, 10, 50, 60],
    'point': [0, 0, 0, 1, 1]
}
df = pd.DataFrame(data)

# Convert 'timepoint' column to datetime type
df['timepoint'] = pd.to_datetime(df['timepoint'])

# Keep track of the previous row and of the point values already shown in the legend
prev_row = None
labelled = set()

fig, ax = plt.subplots()
for _, row in df.iterrows():
    if prev_row is not None:
        # Color each segment by the 'point' value of the row it starts from
        point = prev_row['point']
        color = 'b' if point == 0 else 'r'  # Blue for point=0, red for point=1
        # Label only the first segment per point value to avoid duplicate legend entries
        label = None if point in labelled else f'Point {point}'
        labelled.add(point)
        ax.plot([prev_row['timepoint'], row['timepoint']], [prev_row['value'], row['value']],
                c=color, label=label, marker='o')  # marker='o' also displays the points
    prev_row = row

# Add labels, legend, and show the plot
ax.set_xlabel('Timepoint')
ax.set_ylabel('Value')
ax.legend()
plt.xticks(rotation=45)  # Rotate x-axis labels for better readability
plt.show()
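If looping over rows feels heavy, a vectorized variant is possible with matplotlib's LineCollection. This is a different technique from the loop above, shown only as a sketch: it assumes the same five-row sample data and assumes that each segment should take the color of the row it starts from.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.collections import LineCollection

# Sample data (assumption): a small stand-in for the real DataFrame
df = pd.DataFrame({
    "timepoint": pd.to_datetime(
        ["2021-01-01", "2021-02-01", "2021-11-01", "2021-12-01", "2022-01-01"]
    ),
    "value": [10, 20, 10, 50, 60],
    "point": [0, 0, 0, 1, 1],
})

# LineCollection works on numbers, so convert the dates to matplotlib date numbers
x = mdates.date2num(df["timepoint"].to_numpy())
y = df["value"].to_numpy(dtype=float)

# Build one (start, end) pair per consecutive pair of rows: shape (n - 1, 2, 2)
points = np.column_stack([x, y])
segments = np.stack([points[:-1], points[1:]], axis=1)

# Assumption: each segment takes the color of the row it starts from
seg_colors = np.where(df["point"].to_numpy()[:-1] == 0, "b", "r").tolist()

fig, ax = plt.subplots()
ax.add_collection(LineCollection(segments, colors=seg_colors))
ax.xaxis_date()      # interpret the numeric x values as dates again
ax.autoscale()       # make sure the axis limits cover the collection
fig.autofmt_xdate()  # rotate the date labels
```

Unlike the groupby and pivot versions, this draws one continuous line, with the color switching at the row where point changes.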
We can also use the np.where approach, provided we switch to plt.scatter, as in the code below:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Sample DataFrame
data = {
    'timepoint': ['2021-01-01', '2021-02-01', '2021-11-01', '2021-12-01', '2022-01-01'],
    'value': [10, 20, 10, 50, 60],
    'point': [0, 0, 0, 1, 1]
}
df = pd.DataFrame(data)
# Convert 'timepoint' column to datetime type
df['timepoint'] = pd.to_datetime(df['timepoint'])
# Use np.where to conditionally set colors
col = np.where(df['point'] == 0, 'b', 'r')
# Create the scatter plot with conditional colors
plt.scatter(df['timepoint'], df['value'], c=col, s=50)
# Show the plot
plt.xlabel('Timepoint')
plt.ylabel('Value')
plt.title('Conditional Scatter Plot')
plt.xticks(rotation=45) # Rotate x-axis labels for better readability
plt.grid(True)
plt.show()
Here is the updated code that creates a line plot instead, using a loop over the color groups:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Sample DataFrame
data = {
    'timepoint': ['2021-01-01', '2021-02-01', '2021-11-01', '2021-12-01', '2022-01-01'],
    'value': [10, 20, 10, 50, 60],
    'point': [0, 0, 0, 1, 1]
}
df = pd.DataFrame(data)
# Convert 'timepoint' column to datetime type
df['timepoint'] = pd.to_datetime(df['timepoint'])
# Use np.where to conditionally set colors
col = np.where(df['point'] == 0, 'b', 'r')
# Create the line graph with conditional colors
fig, ax = plt.subplots()
for color, group in df.groupby(col):
    # Pass the group key as the line color, not only as the legend label
    ax.plot(group['timepoint'], group['value'], c=color, label=color)
# Add labels, legend, and show the plot
ax.set_xlabel('Timepoint')
ax.set_ylabel('Value')
plt.xticks(rotation=45) # Rotate x-axis labels for better readability
plt.grid(True)
plt.legend()
plt.show()
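One caveat with grouping by color: every group is drawn as a single line, so if the same color occurs in several separate stretches, the line jumps straight across the rows in between. A possible workaround is to group by consecutive runs of the same color instead. The sketch below uses hypothetical data in which the color switches back to blue at the end; the names run_id and runs are introduced only for this illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical data: point goes 0 -> 1 -> 0, so blue occurs in two separate stretches
df = pd.DataFrame({
    "timepoint": pd.to_datetime(
        ["2021-01-01", "2021-02-01", "2021-03-01", "2021-04-01", "2021-05-01", "2021-06-01"]
    ),
    "value": [10, 20, 10, 50, 60, 30],
    "point": [0, 0, 1, 1, 0, 0],
})

colors = np.where(df["point"] == 0, "b", "r")

# Number the runs: the counter increases wherever the color differs from the previous row
run_id = np.r_[0, np.cumsum(colors[1:] != colors[:-1])]

fig, ax = plt.subplots()
runs = []
for _, run in df.assign(color=colors).groupby(run_id):
    color = run["color"].iloc[0]  # every row in a run shares the same color
    ax.plot(run["timepoint"], run["value"], c=color)
    runs.append((color, len(run)))
```

Each run becomes its own line, so the two blue stretches are no longer connected through the red one. The runs are still separated by a gap at each color change.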
You can check out my Kaggle here.