PySpark:通过重置对窗口进行计数

问题描述 投票:0回答:1

我有一个 PySpark DataFrame,如下所示:

df = spark.createDataFrame(
    data=[
    (1, "GERMANY", "20230606", True),
    (2, "GERMANY", "20230620", False),
    (3, "GERMANY", "20230627", True),
    (4, "GERMANY", "20230705", True),
    (5, "GERMANY", "20230714", False),
    (6, "GERMANY", "20230715", True),
    ],
    schema=["ID", "COUNTRY", "DATE", "FLAG"]
)
df.show()
+---+-------+--------+-----+
| ID|COUNTRY|    DATE| FLAG|
+---+-------+--------+-----+
|  1|GERMANY|20230606| true|
|  2|GERMANY|20230620|false|
|  3|GERMANY|20230627| true|
|  4|GERMANY|20230705| true|
|  5|GERMANY|20230714|false|
|  6|GERMANY|20230715| true|
+---+-------+--------+-----+

DataFrame 有更多国家/地区。我想按照以下逻辑创建一个新列

COUNT_WITH_RESET

  • 如果
    FLAG=False
    ,那么
    COUNT_WITH_RESET=0
  • 如果
    FLAG=True
    ,则
    COUNT_WITH_RESET
    应计算从前一个日期开始的行数,其中
    FLAG=False
    代表该特定国家/地区。

这应该是上面示例的输出。

+---+-------+--------+-----+----------------+
| ID|COUNTRY|    DATE| FLAG|COUNT_WITH_RESET|
+---+-------+--------+-----+----------------+
|  1|GERMANY|20230606| true|               1|
|  2|GERMANY|20230620|false|               0|
|  3|GERMANY|20230627| true|               1|
|  4|GERMANY|20230705| true|               2|
|  5|GERMANY|20230714|false|               0|
|  6|GERMANY|20230715| true|               1|
+---+-------+--------+-----+----------------+

我尝试在窗口上使用

row_number()
,但无法重置计数。我也尝试过
.rowsBetween(Window.unboundedPreceding, Window.currentRow)
。这是我的方法:

from pyspark.sql.window import Window
import pyspark.sql.functions as F

window_reset = Window.partitionBy("COUNTRY").orderBy("DATE")

df_with_reset = (
    df
    .withColumn("COUNT_WITH_RESET", F.when(~F.col("FLAG"), 0)
                .otherwise(F.row_number().over(window_reset)))
)

df_with_reset.show()
+---+-------+--------+-----+----------------+
| ID|COUNTRY|    DATE| FLAG|COUNT_WITH_RESET|
+---+-------+--------+-----+----------------+
|  1|GERMANY|20230606| true|               1|
|  2|GERMANY|20230620|false|               0|
|  3|GERMANY|20230627| true|               3|
|  4|GERMANY|20230705| true|               4|
|  5|GERMANY|20230714|false|               0|
|  6|GERMANY|20230715| true|               6|
+---+-------+--------+-----+----------------+

这显然是错误的,因为我的窗口仅按国家/地区分区,但我走在正确的轨道上吗? PySpark 中是否有特定的内置函数来实现此目的?我需要 UDF 吗?任何帮助将不胜感激。

pyspark count spark-window-function
1个回答
0
投票

创建有序窗口规范来对数据帧进行分区,然后计算倒置

FLAG
列上的累积和以分配组编号,以便区分以
blocks
 开头的不同 
false

W1 = Window.partitionBy('COUNTRY').orderBy('DATE')
df1 = df.withColumn('blocks', F.sum((~F.col('FLAG')).cast('long')).over(W1))

df1.show()
# +---+-------+--------+-----+------+
# | ID|COUNTRY|    DATE| FLAG|blocks|
# +---+-------+--------+-----+------+
# |  1|GERMANY|20230606| true|     0|
# |  2|GERMANY|20230620|false|     1|
# |  3|GERMANY|20230627| true|     1|
# |  4|GERMANY|20230705| true|     1|
# |  5|GERMANY|20230714|false|     2|
# |  6|GERMANY|20230715| true|     2|
# +---+-------+--------+-----+------+

COUNTRY
blocks
对数据帧进行分区,然后计算有序分区上的行号以创建顺序计数器

W2 = Window.partitionBy('COUNTRY', 'blocks').orderBy('DATE')
df1 = df1.withColumn('COUNT_WITH_RESET', F.row_number().over(W2) - 1)


df1.show()
# +---+-------+--------+-----+------+----------------+
# | ID|COUNTRY|    DATE| FLAG|blocks|COUNT_WITH_RESET|
# +---+-------+--------+-----+------+----------------+
# |  1|GERMANY|20230606| true|     0|               0|
# |  2|GERMANY|20230620|false|     1|               0|
# |  3|GERMANY|20230627| true|     1|               1|
# |  4|GERMANY|20230705| true|     1|               2|
# |  5|GERMANY|20230714|false|     2|               0|
# |  6|GERMANY|20230715| true|     2|               1|
# +---+-------+--------+-----+------+----------------+
© www.soinside.com 2019 - 2024. All rights reserved.