在 seaborn (sns) kdeplot 中绘制四分位数

问题描述 投票:0回答:1

我想在每个子图(轴)上绘制两条核密度估计(KDE)曲线,并在每条曲线上标出四分位数。我已经尝试了下面的代码。不幸的是,“data_x_sim, data_y_sim = p2.lines[0].get_data()”这一行无法读取到正确的数据,四分位数线始终是按照第一条 KDE 曲线绘制的。为什么会这样?

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import NullFormatter
import numpy as np
import random

def plotverweildauernwithsim(df):
    """Draw stacked KDE plots of 'years' for observed vs. simulated rows.

    One subplot is created per integer class 1..df['cs'].max(). In each
    subplot the function draws a scatter of the observed rows, a KDE of the
    observed ('method' == 'data') and of the simulated ('method' == 'sim')
    'years' values, and vertical marker lines at the 25%/50%/75% quantiles
    of each group.

    Args:
        df: DataFrame from which the columns 'method', 'years', 'cs' and
            'cs_bandwith' are read. The 'years' column of ``df`` is
            overwritten with its float64 cast.

    Returns:
        None. The figure is built through pyplot; this function does not
        call ``plt.show()``.
    """

    def rand_float_range(start, end):
        """Return a random float scaled from [0, 1) into [start, end).

        NOTE(review): not called anywhere in the enclosing function.
        """
        return random.random() * (end - start) + start

    def find_nearest(array, value):
        """Return the index of the element of ``array`` closest to ``value``."""
        array = np.asarray(array)
        idx = (np.abs(array - value)).argmin()
        return idx

    # Split into observed and simulated rows.
    df_data=df[df['method']=='data']
    df_sim=df[df['method']=='sim']


    # NOTE(review): this cast runs after df_data/df_sim were sliced, so
    # presumably the two subsets do not see it -- confirm.
    df['years']=df['years'].astype('float64')

    sns.set()
    # sharex=True: the x-axis is shared by all subplots
    # fig, ax = plt.subplots(): tuple containing figure and axes object
    fig, ax = plt.subplots(nrows=df['cs'].max(), sharex=True)
    nrows=len(ax)
    # Only the last (bottom) subplot gets an x-axis label.
    ax[df['cs'].max()-1].set_xlabel('Verweildauer in Jahren')

    # s[t] = number of observed rows with cs == t; a 0 is inserted at index 0
    # so the list can be indexed directly with the 1-based class id.
    s=[(df_data["cs"] == t).sum() for t in range(1, df_data['cs'].max()+1)]
    s.insert(0, 0)
    print(s)

    for i in range(1, df['cs'].max()+1):
        # 25% / 50% / 75% quantiles of 'years' for class i, taken from describe().
        quatrilesx_data = [df_data[df_data['cs']==i]["years"].astype(float).describe()['25%'],df_data[df_data['cs']==i]["years"].astype(float).describe()['50%'], df_data[df_data['cs']==i]["years"].astype(float).describe()['75%']]
        quatrilesx_sim = [df_sim[df_sim['cs']==i]["years"].astype(float).describe()['25%'],df_sim[df_sim['cs']==i]["years"].astype(float).describe()['50%'], df_sim[df_sim['cs']==i]["years"].astype(float).describe()['75%']]

        verweildauer_i_data = df_data[df_data['cs'] == i]['years']
        verweildauer_i_sim = df_sim[df_sim['cs'] == i]['years']

        cs_bandwith_i = df_data[df_data['cs'] == i]['cs_bandwith']

        # NOTE(review): the figure size and subplot spacing do not depend on i;
        # they are recomputed on every iteration.
        width=(10*nrows)/3
        height=(7*nrows)/3
        fig.set_size_inches(width,height)
        plt.subplots_adjust(hspace=0)

        # Scatter of the observed rows, shifted by -i-0.5 on the y-axis.
        ax[i-1].scatter(verweildauer_i_data, cs_bandwith_i-i-0.5, alpha=0.1)
        #ax[i-1].scatter(verweildauer_i_sim, cs_bandwith_i-i-0.5, alpha=0.1)


        p1=sns.kdeplot(verweildauer_i_data, ax=ax[i-1], shade=True, cumulative=False)
        # Get the data points of the KDE plot
        data_x_data, data_y_data = p1.lines[0].get_data()

        print(data_x_data.mean())


        p2=sns.kdeplot(verweildauer_i_sim, ax=ax[i-1], shade=True, cumulative=False)
        # NOTE(review): both kdeplot calls draw on ax[i-1], so lines[0] here is
        # still the first ('data') curve rather than the 'sim' curve just drawn.
        # This is why the 'sim' quartile heights follow the first KDE.
        data_x_sim, data_y_sim = p2.lines[0].get_data()

        print(data_x_sim.mean())

        # Plot the quartile lines. The line height is looked up from the KDE
        # curve at the x sample nearest to the quantile.
        # NOTE(review): presumably .3255 approximates where y=0 falls inside the
        # (-0.5, 1) y-limits set below, and height/width rescales the density
        # into axes-fraction units -- confirm.
        for k in range(len(quatrilesx_data)):
            p1.axvline(quatrilesx_data[k], ymin=.3255, ymax=.3255+(height/width)*abs(data_y_data[find_nearest(data_x_data, quatrilesx_data[k])]), linestyle='dotted', color='black', linewidth=.6, alpha=.8)


        for k in range(len(quatrilesx_sim)):
            p2.axvline(quatrilesx_sim[k], ymin=.3255, ymax=.3255+(height/width)*abs(data_y_sim[find_nearest(data_x_sim, quatrilesx_sim[k])]), linestyle='dashdot', color='black', linewidth=.6, alpha=.8)


        ax[i-1].set_ylim(-0.5,1)
        ax[i-1].set_ylabel('ZK ' + str(i))
        ax[i-1].set_yticks([])
        # Draw the "y=0" line
        ax[i-1].axhline(0, linestyle='--', color='blue', linewidth=.5, alpha=.35) # horizontal lines


        handles, labels = ax[i-1].get_legend_handles_labels()
        # Create the legend again; the [2:] slices drop the first two entries.
        leg = ax[i-1].legend(handles[2:], labels[2:])

        # NOTE(review): the first assignment is overwritten on the next line.
        ax2=ax[i-1]
        # The twin axis is used only to show the observed-row count on the right.
        ax2 = ax[i-1].twinx()
        ax2.grid(False)
        ax2.set_ylabel("n = {}".format(s[i]), rotation=0, labelpad=25)
        ax2.set_yticks([])

代码结果:

“我的代码的结果”

matplotlib plot seaborn kde
1个回答
0
投票

您可能想要的是 data_x_sim, data_y_sim = p2.lines[1].get_data()。p2 就是该子图的 ax,其中已经包含第一次绘图的曲线;因此 lines[0] 仍然是第一条曲线,第二条(sim)曲线位于 lines[1]。

这里是一些示例代码。首先,生成一些玩具数据。散点图被排除在示例之外,因为它似乎与问题无关。

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import NullFormatter
import numpy as np
import random

# Toy data set: N observed ('data') rows followed by N simulated ('sim') rows.
N = 500
method_labels = np.repeat(['data', 'sim'], N)
# Ten random centres drawn from [5, 10); each one is repeated N // 5 times so
# that there is one mean per row (10 * N // 5 == 2 * N).
centres = np.random.uniform(5, 10, 10)
years_samples = np.random.normal(np.repeat(centres, N // 5), 1, 2 * N)
# Class id per row, drawn from {1, 2, 3}.
class_ids = np.random.randint(1, 4, 2 * N)
df = pd.DataFrame({'method': method_labels,
                   'years': years_samples,
                   'cs': class_ids})


def find_nearest(array, value):
    """Return the index of the entry in *array* that is closest to *value*.

    *array* is converted with ``np.asarray``. When several entries are
    equally close, the index of the first one is returned.
    """
    distances = np.abs(np.asarray(array) - value)
    return distances.argmin()


def _draw_kde_with_quartiles(axis, values, label, linestyle):
    """Draw a filled KDE of *values* on *axis* and mark its quartiles.

    The KDE is drawn as a plain line and filled manually, so the curve can be
    read back from the Line2D that was just added. This avoids hard-coded
    indices into ``axis.lines`` and does not rely on a filled/shaded kdeplot
    also producing a line.

    Args:
        axis: matplotlib Axes to draw on.
        values: 1-D collection of observations.
        label: legend label for the KDE curve.
        linestyle: line style of the three quartile markers.
    """
    values = np.asarray(values, dtype=float)
    values = values[~np.isnan(values)]
    # A density estimate needs at least two distinct observations.
    if values.size < 2 or values.min() == values.max():
        return

    lines_before = len(axis.lines)
    sns.kdeplot(values, ax=axis, label=label)
    if len(axis.lines) == lines_before:
        # Nothing was drawn, so there is no curve to annotate.
        return

    # The curve just drawn is the most recently added line on this Axes.
    curve = axis.lines[-1]
    xs, ys = curve.get_data()
    axis.fill_between(xs, ys, color=curve.get_color(), alpha=.25)

    # Quartile markers run from the baseline up to the curve, in data
    # coordinates; the curve height is interpolated at each quartile.
    quartiles = np.quantile(values, [.25, .50, .75])
    axis.vlines(quartiles, 0, np.interp(quartiles, xs, ys),
                linestyles=linestyle, colors='black', linewidth=.6, alpha=.8)


# Cast before splitting so that both subsets hold float data.
df['years'] = df['years'].astype('float64')
df_data = df[df['method'] == 'data']
df_sim = df[df['method'] == 'sim']

n_classes = int(df['cs'].max())

sns.set()
# squeeze=False keeps the result an array even when there is a single class.
fig, ax_grid = plt.subplots(nrows=n_classes, sharex=True, squeeze=False)
ax = ax_grid[:, 0]
nrows = len(ax)

# The figure geometry does not depend on the class, so set it once.
width = (10 * nrows) / 3
height = (7 * nrows) / 3
fig.set_size_inches(width, height)
plt.subplots_adjust(hspace=0)
ax[-1].set_xlabel('Verweildauer in Jahren')

# s[i] is the number of observed rows in class i; index 0 is padding so the
# list can be addressed with the 1-based class id. It covers every class that
# is plotted, including classes that only occur in the simulated rows.
s = [0] + [int((df_data['cs'] == t).sum()) for t in range(1, n_classes + 1)]

for i, axis in enumerate(ax, start=1):
    verweildauer_i_data = df_data.loc[df_data['cs'] == i, 'years']
    verweildauer_i_sim = df_sim.loc[df_sim['cs'] == i, 'years']

    _draw_kde_with_quartiles(axis, verweildauer_i_data, 'data', 'dotted')
    _draw_kde_with_quartiles(axis, verweildauer_i_sim, 'sim', 'dashdot')

    axis.set_ylim(-0.5, 1)
    axis.set_ylabel('ZK ' + str(i))
    axis.set_yticks([])
    # draw the "y=0" line
    axis.axhline(0, linestyle='--', color='blue', linewidth=.5, alpha=.35)

    # Only build a legend when at least one labelled curve was drawn.
    if axis.get_legend_handles_labels()[0]:
        axis.legend()

    # The twin axis is used only to show the sample size on the right-hand side.
    ax2 = axis.twinx()
    ax2.grid(False)
    ax2.set_ylabel("n = {}".format(s[i]), rotation=0, labelpad=25)
    ax2.set_yticks([])

plt.show()

example plot

© www.soinside.com 2019 - 2024. All rights reserved.