如果 Spark 中的窗口未满,有没有办法不计算窗口函数

问题描述 投票:0回答:1

我有一个基于前一天的窗口。我希望计算每一行的滚动平均值,但是如果我没有完整的一天数据,我想返回

Null
。例如

           dateTime group val rolling_average
2023-01-01 00:00:00     1   1            Null
2023-01-01 00:05:00     1   2            Null
2023-01-01 00:10:00     1   3            Null
2023-01-01 00:15:00     1   4            Null
...
2023-01-02 00:00:00     1   2             2.5

所以在这里,对于前 4 行,滚动平均值将为

Null
,因为窗口内没有完整的一天。但是对于我显示的最后一行,我们将能够计算整个窗口的平均值。目前,但是我得到以下不良结果:

           dateTime group val rolling_average
2023-01-01 00:00:00     1   1               1
2023-01-01 00:05:00     1   2             1.5
2023-01-01 00:10:00     1   3               2
2023-01-01 00:15:00     1   4             2.5
...
2023-01-02 00:00:00     1   2             2.5

我的窗口设置为:

window = Window.partitionBy("group").orderBy(f.col("dateTime").cast("long")).rangeBetween(-86400, 0)

我知道这可以通过后处理条件检查来完成。我可以检查创建第二个窗口

window = Window.partitionBy("cropseason_id").orderBy(f.col("dateTime").cast("long"))

然后计算行号并相应调整滚动窗口列:

(
    df
    .withColumn("rolling_average", f.avg(f.col("val")).over(window))
    .withColumn("rn", f.row_number())
    .withColumn("rolling_average", f.when(f.col("rn") < 288, f.lit(None)).otherwise(f.col("rolling_average")))
)

然而,我的问题是目前这个数据集是 5 分钟的数据,但它可能是我使用的每小时或分钟的数据,所以我需要一个通用的解决方案。我以为我可以在窗口的构造中设置它?

apache-spark pyspark apache-spark-sql
1个回答
0
投票

也许你可以计算窗口中第一条记录的时间差异,然后仅当此差异大于 84600 时才显示平均值,否则显示 null

它应该可以在任何时间间隔内正常工作,它只会跳过任何不早于窗口中第一条记录 24 小时的内容

import datetime
from pyspark.sql import Window
import pyspark.sql.functions as F
from pyspark.sql.types import *

data = [
    (datetime.datetime.strptime("2023-01-01 00:00:00", "%Y-%m-%d %H:%M:%S"), 1, 1),
    (datetime.datetime.strptime("2023-01-01 00:05:00", "%Y-%m-%d %H:%M:%S"), 1, 2),
    (datetime.datetime.strptime("2023-01-01 06:05:00", "%Y-%m-%d %H:%M:%S"), 1, 3),
    (datetime.datetime.strptime("2023-01-01 18:00:00", "%Y-%m-%d %H:%M:%S"), 1, 4),
    (datetime.datetime.strptime("2023-01-02 00:00:00", "%Y-%m-%d %H:%M:%S"), 1, 2),
]

myschema = StructType(
    [
        StructField("dateTime", TimestampType()),
        StructField("group", IntegerType()),
        StructField("val", IntegerType()),
    ]
)
df = spark.createDataFrame(data=data, schema=myschema)

window = Window.partitionBy("group").orderBy(["dateTime"])

df = df.withColumn(
    "timeDiffFromPrevInSeconds",
    (
        (
            (F.col("dateTime")).cast("long")
            - F.first("dateTime").over(window).cast("long")
        )
    ).cast("long"),
).withColumn(
    "rolling_average",
    F.when(
        F.col("timeDiffFromPrevInSeconds") >= 86400, F.avg(F.col("val")).over(window)
    ).otherwise(None),
)

df.show()

输出为:

+-------------------+-----+---+---------------+
|           dateTime|group|val|rolling_average|
+-------------------+-----+---+---------------+
|2023-01-01 00:00:00|    1|  1|           null|
|2023-01-01 00:05:00|    1|  2|           null|
|2023-01-01 06:05:00|    1|  3|           null|
|2023-01-01 18:00:00|    1|  4|           null|
|2023-01-02 00:00:00|    1|  2|            2.4|
+-------------------+-----+---+---------------+
© www.soinside.com 2019 - 2024. All rights reserved.