我有这段代码:
shifted_pd = account_level_pd_shifts.filter(account_level_pd_shifts['WITHIN_PD_EXCLUSION'] == True).groupBy(['FORWARD_LOOK_MODEL', 'FOR_PD_TYPE']).agg(f.avg('PD_SHIFT').alias('SHIFTED_PD'))
account_level_pd_shifts = account_level_pd_shifts.drop('SHIFTED_PD').join(f.broadcast(shifted_pd), on=['FORWARD_LOOK_MODEL', 'FOR_PD_TYPE'], how='left')
我无法计算groupby的平均值,只能取第一个值,然后使用f.lit()将其添加到新列中,因为分组的数据有4个项目,不是1表示我得到,我得到4意味着每个。
这两行代码在23个循环迭代中,并且23个连接不是很好。有没有一种方法可以避免这种情况,并以某种方式直接添加每个组的平均值,或者不加入大数据框而直接添加?
如果不清楚,我可以提供更多信息:)
感谢您的帮助
您正在寻找的是窗口功能。您想计算Window
上的条件平均值:
from pyspark.sql import Window
from pyspark.sql.functions import col, when, avg
w = Window.partitionBy('FORWARD_LOOK_MODEL', 'FOR_PD_TYPE')
account_level_pd_shifts = account_level_pd_shifts.withColumn("SHIFTED_PD",
avg(when(col("WITHIN_PD_EXCLUSION"),
col("PD_SHIFT")
)
).over(w)
)