应用滞后函数和.when条件在Pyspark中创建新变量

问题描述 投票:0回答:1

我正在尝试创建一个新变量“Round”,以指示治疗轮次,该治疗轮次以每次治疗“Event”、“Gap_Day”之间错过的天数为条件。如果“Gap_Day”的值小于 10 天,我尝试使用 lag() 函数结转先前的“Round”值,否则它将开始下一个治疗轮次(上一个治疗轮次 + 1)。

我正在使用窗口函数按患者“ID”对数据进行分区,并按治疗“事件”排序。

新变量的创建工作正常,但仅将其自身应用于一行,窗口中的其余值保持为空。如何将变量创建应用于所有行?

from pyspark.sql.functions import col, lag, lead, when, first, last
from pyspark.sql.window import Window

data = sc.parallelize([
    ('A', 1, None),
    ('A', 2, 3),
    ('A', 3, 13),
    ('A', 4, 4),
    ('B', 1, None),
    ('B', 2, 22),
    ('B', 3, 3),
    ('B', 4, 14),
    ('B', 5, 11),
    ])

df_reprex = spark.createDataFrame(data, ['ID', 'Event', 'Gap_Day'])

window_def = Window.partitionBy("ID").orderBy("Event")

df_reprex = df_reprex.withColumn("Round", when(col("Gap_Day").isNull(), 1))
df_reprex = df_reprex.withColumn("Round", when((col("Gap_Day") < 10), 
                                                   lag("Round").over(window_def))
                                          .when((col("Gap_Day") >= 10), 
                                                   lag("Round").over(window_def) + 1)
                                          .otherwise(col("Round")))

df_reprex.show()

这是“Round”的输出以及我添加的一列以显示预期的输出应该是什么。

+---+-----+-------+-----+--------------+
| ID|Event|Gap_Day|Round|Round_Expected|
+---+-----+-------+-----+--------------+
|  A|    1|   null|    1|             1|
|  A|    2|      3|    1|             1|
|  A|    3|     13| null|             2|
|  A|    4|      4| null|             2|
|  B|    1|   null|    1|             1|
|  B|    2|     22|    2|             2|
|  B|    3|      3| null|             2|
|  B|    4|     14| null|             3|
|  B|    5|     11| null|             4|
+---+-----+-------+-----+--------------+

我尝试用以下方法调整窗口:

window_def = Window.partitionBy("ID").orderBy("Event").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

但这给出了以下错误消息:

AnalysisException: Cannot specify window frame for lag function.

提前感谢您的建议!

apache-spark pyspark window-functions lag
1个回答
0
投票

我认为你正在尝试做 cumsum 并且你不需要

lag

window_def = Window.partitionBy('ID').orderBy('Event')
df_reprex = df_reprex.withColumn('Round', 
    sum(
        when(col('Gap_Day') < 10, 0).otherwise(1)
    ).over(window_def))
© www.soinside.com 2019 - 2024. All rights reserved.