我有一个像这样的数据框(但更大):
id start end
0 10 20
1 11 13
2 14 18
3 22 30
4 25 27
5 28 31
我正在尝试有效地合并 PySpark 中的重叠间隔,同时保存在新列“ids”中,其中间隔已合并,因此它看起来像这样:
start end ids
10 20 [0,1,2]
22 31 [3,4,5]
可视化:
来自:
我可以在不使用 udf 的情况下完成此操作吗?
编辑:id和start的顺序不一定相同。
您可以使用窗口函数将先前行与当前行进行比较,以构建一个列来确定当前行是否是新间隔的开始,然后对该列求和以构建一个间隔 id。然后您按此间隔 ID 进行分组以获得最终的数据帧。
如果您调用
input_df
您的输入数据框,代码将如下:
from pyspark.sql import Window
from pyspark.sql import functions as F
all_previous_rows_window = Window \
.orderBy('start') \
.rowsBetween(Window.unboundedPreceding, Window.currentRow)
result = input_df \
.withColumn('max_previous_end', F.max('end').over(all_previous_rows_window)) \
.withColumn('interval_change', F.when(
F.col('start') > F.lag('max_previous_end').over(Window.orderBy('start')),
F.lit(1)
).otherwise(F.lit(0))) \
.withColumn('interval_id', F.sum('interval_change').over(all_previous_rows_window)) \
.drop('interval_change', 'max_previous_end') \
.groupBy('interval_id') \
.agg(
F.collect_list('id').alias('ids'),
F.min('start').alias('start'),
F.max('end').alias('end')
).drop('interval_id')
因此您可以合并间隔,而无需任何用户定义的函数。然而,每次我们使用窗口时,代码仅在一个执行器上执行,因为我们的窗口没有分区。