如何在散点图上绘制多个类别的多项式模型

问题描述 投票:0回答:1

我正在使用标准钻石数据集,我需要创建以下类型的图表:

目前我已尝试的代码如下:1)

plt.figure(figsize=(12, 8), dpi=200)

# Scatter of every diamond coloured by cut; keep the returned Axes so the
# line layer and all labelling target the same plot explicitly.
ax = sns.scatterplot(data=df, x='carat', y='price', hue='cut', palette='viridis')

# One line per cut through the raw data on the same Axes (no model is fitted).
sns.lineplot(data=df, x='carat', y='price', hue='cut', palette='viridis', ax=ax)

ax.set_xlabel('Carat')
ax.set_ylabel('Price')
ax.set_title('Scatter Plot of Price vs. Carat with Curved Lines (Viridis Palette)')

# Place the legend just outside the upper-right corner of the Axes.
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

plt.show()

2)

plt.figure(figsize=(12, 8), dpi=200)

# Draw a separate regplot (scatter + default-order fit) for each cut,
# labelled with the cut name so it appears in the legend.
cuts = df['cut'].unique()
for cut_name in cuts:
    subset = df[df['cut'] == cut_name]
    sns.regplot(
        data=subset,
        x='carat',
        y='price',
        scatter_kws={'s': 10},
        label=cut_name,
    )

plt.xlabel('Carat')
plt.ylabel('Price')
plt.title('Regression Plot of Price vs. Carat by Cut')

plt.legend(title='Cut')

plt.show()

3)

sns.jointplot(data=df, x='carat', y='price', hue='cut')

请问您能解释一下如何获得我需要的图表吗?

python-3.x numpy regression seaborn polynomials
1个回答
0
投票

数据和导入

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Load the "diamonds" example dataset through seaborn's dataset loader.
df = sns.load_dataset(name='diamonds')

   carat      cut color clarity  depth  table  price     x     y     z
0   0.23    Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43
1   0.21  Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31
2   0.23     Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31
3   0.29  Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63
4   0.31     Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75

使用 `np.polyfit` 为每个 cut 类别拟合多项式系数,再用 `np.poly1d` 将这些系数包装成可调用的多项式模型。

# create figure and Axes
fig, ax = plt.subplots(figsize=(12, 8))

# Degree of the polynomial fitted to each cut.
degree = 5

# Fix ONE explicit order of the cut levels and use it for both the scatter
# (via hue_order) and the fitted curves. Without this, seaborn orders a
# categorical hue by its categories while df.cut.unique() returns the order
# of appearance, so the curve colours would not match the scatter colours.
# Unused categories are dropped because np.polyfit cannot fit empty data.
cuts = df['cut']
if cuts.dtype.name == 'category':
    cut_order = list(cuts.cat.remove_unused_categories().cat.categories)
else:
    cut_order = list(cuts.unique())

# plot the scatter points
sns.scatterplot(data=df, x='carat', y='price', hue='cut', hue_order=cut_order,
                palette='viridis', s=10, alpha=0.4, ec='none', ax=ax)

# matching palette colors from viridis, one per cut level (same order as hue_order)
colors = sns.color_palette('viridis', n_colors=len(cut_order))

# iterate through the cut levels and their matching color
for level, color in zip(cut_order, colors):

    # select the data for a given cut
    subset = df[df['cut'].eq(level)]

    # fit the polynomial coefficients and wrap them in a callable model
    p = np.poly1d(np.polyfit(subset['carat'], subset['price'], degree))

    # evaluate the model on a dense grid spanning this cut's carat range
    xp = np.linspace(subset['carat'].min(), subset['carat'].max(), 1000)

    # draw the model as a dotted line in the cut's colour; a plain Axes.plot
    # avoids lineplot's per-x aggregation and adds no legend entry
    ax.plot(xp, p(xp), color=color, ls=':')

sns.move_legend(ax, bbox_to_anchor=(1, 0.5), loc='center left', frameon=False)

`sns.lmplot`

- 如果 `order` 大于 1,则使用 `numpy.polyfit` 来估计多项式回归。
# plot the polynomial model
g = sns.lmplot(data=df, x='carat', y='price', hue='cut', palette='viridis', order=5, truncate=True, ci=None, scatter_kws={'s': 10, 'alpha': 1}, height=8, aspect=1.25)

# Access the Axes created by lmplot so the manual model is drawn on this plot.
# With no row/col facets the grid has a single Axes, exposed as g.ax; reusing
# an `ax` from an earlier figure would draw the curves on the wrong plot.
ax = g.ax

# Plot the manual model for comparison. lmplot fits order > 1 with
# numpy.polyfit over each cut's data range (truncate=True), so these black
# dotted curves should coincide with lmplot's fitted lines.
for cut in df['cut'].unique():
    subset = df[df['cut'].eq(cut)]
    p = np.poly1d(np.polyfit(subset['carat'], subset['price'], 5))
    xp = np.linspace(subset['carat'].min(), subset['carat'].max(), 1000)
    ax.plot(xp, p(xp), color='k', ls=':')

© www.soinside.com 2019 - 2024. All rights reserved.