如何将误差线添加到交互图(统计模型)?

问题描述 投票:0回答:2

我有以下代码:

import numpy as np
import matplotlib.pyplot as plt
from statsmodels.graphics.factorplots import interaction_plot

a = np.array( [ item for item in [ 'a1', 'a2', 'a3' ] for _ in range(30) ] )
b = np.array( [ item for _ in range(45) for item in [ 'b1', 'b2' ] ] )
np.random.seed(123)
mse = np.ravel( np.column_stack( (np.random.normal(-1, 1, size=45 ), np.random.normal(2, 0.5, size=45 ) )) )
f = interaction_plot( a, b, mse )

这给出了:


有没有一种简单的方法可以直接为每个点添加误差线?

f.axes.errorbar()?

或者直接用 matplotlib 绘图更好?

python matplotlib plot statsmodels
2个回答
1
投票

嗯,看来该功能还没有直接支持,所以我决定直接修改源代码并创建一个新功能。我把它贴在这里,也许对某人有用。


def int_plot(x, trace, response, func=np.mean, ax=None, plottype='b',
                     xlabel=None, ylabel=None, colors=[], markers=[],
                     linestyles=[], legendloc='best', legendtitle=None,
# - - - My changes !!
                     errorbars=False, errorbartyp='std',
# - - - - 
                     **kwargs):

    data = DataFrame(dict(x=x, trace=trace, response=response))
    plot_data = data.groupby(['trace', 'x']).aggregate(func).reset_index()

# - - - My changes !!
    if errorbars:
        if errorbartyp == 'std':
            yerr = data.groupby(['trace', 'x']).aggregate( lambda xx: np.std(xx,ddof=1) ).reset_index()
        elif errorbartyp == 'ci95':
            yerr = data.groupby(['trace', 'x']).aggregate( t_ci ).reset_index()
        else:
            raise ValueError("Type of error bars %s not understood" % errorbartyp)
# - - - - - - -
    n_trace = len(plot_data['trace'].unique())

    if plottype == 'both' or plottype == 'b':
        for i, (values, group) in enumerate(plot_data.groupby(['trace'])):
            # trace label
            label = str(group['trace'].values[0])
# - - - My changes !!
            if errorbars:
                ax.errorbar(group['x'], group['response'], 
                            yerr=yerr.loc[ yerr['trace']==values ]['response'].values, 
                        color=colors[i], ecolor='black',
                        marker=markers[i], label='',
                        linestyle=linestyles[i], **kwargs)
# - - - - - - - - - - 
            ax.plot(group['x'], group['response'], color=colors[i],
                    marker=markers[i], label=label,
                    linestyle=linestyles[i], **kwargs)

这样,我就可以得到这个情节:

f = int_plot( a, b, mse, errorbars=True, errorbartyp='std' )


注意:代码还可以使用函数

t_ci()
来聚合误差线。我这样定义函数:

def t_ci( x, C=0.95 ):
    from scipy.stats import t

    x = np.array( x )
    n = len( x )
    tstat = t.ppf( (1-C)/2, n )
    return np.std( x, ddof=1 ) * tstat / np.sqrt( n )

同样,我只是稍微调整了该功能以适应我当前的需求。原始函数可以在这里找到:)


0
投票

也许实现误差线的一种可疑方法是使用axes.errorbar(),但这是我的解决方法。这样做会在图表中添加另一条线,然后您需要将其与交互图的线对齐。

© www.soinside.com 2019 - 2024. All rights reserved.