我有一个 PySpark DataFrame,如下所示:
df = spark.createDataFrame(
data=[
(1, "GERMANY", "20230606", True),
(2, "GERMANY", "20230620", False),
(3, "GERMANY", "20230627", True),
(4, "GERMANY", "20230705", True),
(5, "GERMANY", "20230714", False),
(6, "GERMANY", "20230715", True),
],
schema=["ID", "COUNTRY", "DATE", "FLAG"]
)
df.show()
+---+-------+--------+-----+
| ID|COUNTRY| DATE| FLAG|
+---+-------+--------+-----+
| 1|GERMANY|20230606| true|
| 2|GERMANY|20230620|false|
| 3|GERMANY|20230627| true|
| 4|GERMANY|20230705| true|
| 5|GERMANY|20230714|false|
| 6|GERMANY|20230715| true|
+---+-------+--------+-----+
DataFrame 有更多国家/地区。我想按照以下逻辑创建一个新列
COUNT_WITH_RESET
:
FLAG=False
,那么 COUNT_WITH_RESET=0
。FLAG=True
,则 COUNT_WITH_RESET
应计算从前一个日期开始的行数,其中 FLAG=False
代表该特定国家/地区。这应该是上面示例的输出。
+---+-------+--------+-----+----------------+
| ID|COUNTRY| DATE| FLAG|COUNT_WITH_RESET|
+---+-------+--------+-----+----------------+
| 1|GERMANY|20230606| true| 1|
| 2|GERMANY|20230620|false| 0|
| 3|GERMANY|20230627| true| 1|
| 4|GERMANY|20230705| true| 2|
| 5|GERMANY|20230714|false| 0|
| 6|GERMANY|20230715| true| 1|
+---+-------+--------+-----+----------------+
我尝试在窗口上使用
row_number()
,但无法重置计数。我也尝试过.rowsBetween(Window.unboundedPreceding, Window.currentRow)
。这是我的方法:
from pyspark.sql.window import Window
import pyspark.sql.functions as F
window_reset = Window.partitionBy("COUNTRY").orderBy("DATE")
df_with_reset = (
df
.withColumn("COUNT_WITH_RESET", F.when(~F.col("FLAG"), 0)
.otherwise(F.row_number().over(window_reset)))
)
df_with_reset.show()
+---+-------+--------+-----+----------------+
| ID|COUNTRY| DATE| FLAG|COUNT_WITH_RESET|
+---+-------+--------+-----+----------------+
| 1|GERMANY|20230606| true| 1|
| 2|GERMANY|20230620|false| 0|
| 3|GERMANY|20230627| true| 3|
| 4|GERMANY|20230705| true| 4|
| 5|GERMANY|20230714|false| 0|
| 6|GERMANY|20230715| true| 6|
+---+-------+--------+-----+----------------+
这显然是错误的,因为我的窗口仅按国家/地区分区,但我走在正确的轨道上吗? PySpark 中是否有特定的内置函数来实现此目的?我需要 UDF 吗?任何帮助将不胜感激。
创建有序窗口规范来对数据帧进行分区,然后计算倒置
FLAG
列上的累积和以分配组编号,以便区分以 blocks
开头的不同
false
行
W1 = Window.partitionBy('COUNTRY').orderBy('DATE')
df1 = df.withColumn('blocks', F.sum((~F.col('FLAG')).cast('long')).over(W1))
df1.show()
# +---+-------+--------+-----+------+
# | ID|COUNTRY| DATE| FLAG|blocks|
# +---+-------+--------+-----+------+
# | 1|GERMANY|20230606| true| 0|
# | 2|GERMANY|20230620|false| 1|
# | 3|GERMANY|20230627| true| 1|
# | 4|GERMANY|20230705| true| 1|
# | 5|GERMANY|20230714|false| 2|
# | 6|GERMANY|20230715| true| 2|
# +---+-------+--------+-----+------+
按
COUNTRY
和 blocks
对数据帧进行分区,然后计算有序分区上的行号以创建顺序计数器
W2 = Window.partitionBy('COUNTRY', 'blocks').orderBy('DATE')
df1 = df1.withColumn('COUNT_WITH_RESET', F.row_number().over(W2) - 1)
df1.show()
# +---+-------+--------+-----+------+----------------+
# | ID|COUNTRY| DATE| FLAG|blocks|COUNT_WITH_RESET|
# +---+-------+--------+-----+------+----------------+
# | 1|GERMANY|20230606| true| 0| 0|
# | 2|GERMANY|20230620|false| 1| 0|
# | 3|GERMANY|20230627| true| 1| 1|
# | 4|GERMANY|20230705| true| 1| 2|
# | 5|GERMANY|20230714|false| 2| 0|
# | 6|GERMANY|20230715| true| 2| 1|
# +---+-------+--------+-----+------+----------------+