如何使用groupby+transform代替pipe?

问题描述 投票:0回答:2

假设我有一个像这样的数据框

import pandas as pd
from scipy import stats

df = pd.DataFrame(
    {
        'group': list('abaab'),
        'val1': range(5),
        'val2': range(2, 7),
        'val3': range(4, 9)
    }
)

  group  val1  val2  val3
0     a     0     2     4
1     b     1     3     5
2     a     2     4     6
3     a     3     5     7
4     b     4     6     8

现在我想使用

group
列中的两个(可能是所有对,所以我不想在任何地方硬编码列名称)计算
vali
列中每个组的线性回归。

潜在的实现可能看起来像这样基于

pipe

def do_lin_reg_pipe(df, group_col, col1, col2):
    group_names = df[group_col].unique()
    df_subsets = []
    for s in group_names:
        df_subset = df.loc[df[group_col] == s]
        x = df_subset[col1].values
        y = df_subset[col2].values
        slope, intercept, r, p, se = stats.linregress(x, y)
        df_subset = df_subset.assign(
            slope=slope,
            intercept=intercept,
            r=r,
            p=p,
            se=se
        )
        df_subsets.append(df_subset)
    return pd.concat(df_subsets)

然后我就可以使用

df_linreg_pipe = (
    df.pipe(do_lin_reg_pipe, group_col='group', col1='val1', col2='val3')
      .assign(p=lambda d: d['p'].round(3))
)

这给出了期望的结果

  group  val1  val2  val3  slope  intercept    r    p   se
0     a     0     2     4    1.0        4.0  1.0  0.0  0.0
2     a     2     4     6    1.0        4.0  1.0  0.0  0.0
3     a     3     5     7    1.0        4.0  1.0  0.0  0.0
1     b     1     3     5    1.0        4.0  1.0  0.0  0.0
4     b     4     6     8    1.0        4.0  1.0  0.0  0.0

我不喜欢的是我必须循环遍历组,使用 和

append
,然后还使用
concat
,所以我想我应该以某种方式使用
groupby
transform
但我不明白这个工作。函数调用应该是这样的

df_linreg_transform = df.copy()
df_linreg_transform[['slope', 'intercept', 'r', 'p', 'se']] = (
    df.groupby('group').transform(do_lin_reg_transform, col1='val1', col2='val3')
)

问题是如何定义

do_lin_reg_transform
;我想要一些类似的东西

def do_lin_reg_transform(df, col1, col2):
    
    x = df[col1].values
    y = df[col2].values
    slope, intercept, r, p, se = stats.linregress(x, y)

    return (slope, intercept, r, p, se)

但是 - 当然 - 会崩溃

KeyError

密钥错误:'val1'

如何实现

do_lin_reg_transform
使其与
groupby
transform
一起使用?

python pandas dataframe pandas-groupby
2个回答
3
投票

由于计算需要访问多个列而无法使用

groupby...transform
,因此我们的想法是使用
groupby...apply
map
将结果广播到每一行:

cols = ['slope', 'intercept', 'r', 'p', 'se']
lingress = lambda x: stats.linregress(x['val1'], x['val3'])

df[cols] = pd.DataFrame.from_records(df['group'].map(df.groupby('group').apply(lingress)))
print(df)

# Output
  group  val1  val2  val3  slope  intercept    r             p   se
0     a     0     2     4    1.0        4.0  1.0  9.003163e-11  0.0
1     b     1     3     5    1.0        4.0  1.0  0.000000e+00  0.0
2     a     2     4     6    1.0        4.0  1.0  9.003163e-11  0.0
3     a     3     5     7    1.0        4.0  1.0  9.003163e-11  0.0
4     b     4     6     8    1.0        4.0  1.0  0.000000e+00  0.0

1
投票

转换旨在聚合单列的结果。回归需要多个,因此您应该使用

apply

如果您愿意,您可以定义聚合以返回 DataFrame 而不是 Series(因此结果不会减少)。为此,您需要确保索引是唯一的。然后

concat
返回结果,使其与索引对齐。如果分组列超过 1 个,您不会遇到任何问题。

def group_reg(gp, col1, col2):
    df = pd.DataFrame([stats.linregress(gp[col1], gp[col2])]*len(gp), 
                      columns=['slope', 'intercept', 'r', 'p', 'se'],
                      index=gp.index)
    return df

pd.concat([df, df.groupby('group').apply(group_reg, col1='val1', col2='val3')], axis=1)

  group  val1  val2  val3  slope  intercept    r             p   se
0     a     0     2     4    1.0        4.0  1.0  9.003163e-11  0.0
1     b     1     3     5    1.0        4.0  1.0  0.000000e+00  0.0
2     a     2     4     6    1.0        4.0  1.0  9.003163e-11  0.0
3     a     3     5     7    1.0        4.0  1.0  9.003163e-11  0.0
4     b     4     6     8    1.0        4.0  1.0  0.000000e+00  0.0
© www.soinside.com 2019 - 2024. All rights reserved.