我有一个基于前一天的窗口。我希望计算每一行的滚动平均值,但是如果我没有完整的一天数据,我想返回
Null
。例如
dateTime group val rolling_average
2023-01-01 00:00:00 1 1 Null
2023-01-01 00:05:00 1 2 Null
2023-01-01 00:10:00 1 3 Null
2023-01-01 00:15:00 1 4 Null
...
2023-01-02 00:00:00 1 2 2.5
所以在这里,对于前 4 行,滚动平均值将为
Null
,因为窗口内没有完整的一天。但是对于我显示的最后一行,我们将能够计算整个窗口的平均值。目前,但是我得到以下不良结果:
dateTime group val rolling_average
2023-01-01 00:00:00 1 1 1
2023-01-01 00:05:00 1 2 1.5
2023-01-01 00:10:00 1 3 2
2023-01-01 00:15:00 1 4 2.5
...
2023-01-02 00:00:00 1 2 2.5
我的窗口设置为:
window = Window.partitionBy("group").orderBy(f.col("dateTime").cast("long")).rangeBetween(-86400, 0)
我知道这可以通过后处理条件检查来完成。我可以检查创建第二个窗口
window = Window.partitionBy("cropseason_id").orderBy(f.col("dateTime").cast("long"))
然后计算行号并相应调整滚动窗口列:
(
df
.withColumn("rolling_average", f.avg(f.col("val")).over(window))
.withColumn("rn", f.row_number())
.withColumn("rolling_average", f.when(f.col("rn") < 288, f.lit(None)).otherwise(f.col("rolling_average")))
)
然而,我的问题是目前这个数据集是 5 分钟的数据,但它可能是我使用的每小时或分钟的数据,所以我需要一个通用的解决方案。我以为我可以在窗口的构造中设置它?
也许你可以计算窗口中第一条记录的时间差异,然后仅当此差异大于 84600 时才显示平均值,否则显示 null
它应该可以在任何时间间隔内正常工作,它只会跳过任何不早于窗口中第一条记录 24 小时的内容
import datetime
from pyspark.sql import Window
import pyspark.sql.functions as F
from pyspark.sql.types import *
data = [
(datetime.datetime.strptime("2023-01-01 00:00:00", "%Y-%m-%d %H:%M:%S"), 1, 1),
(datetime.datetime.strptime("2023-01-01 00:05:00", "%Y-%m-%d %H:%M:%S"), 1, 2),
(datetime.datetime.strptime("2023-01-01 06:05:00", "%Y-%m-%d %H:%M:%S"), 1, 3),
(datetime.datetime.strptime("2023-01-01 18:00:00", "%Y-%m-%d %H:%M:%S"), 1, 4),
(datetime.datetime.strptime("2023-01-02 00:00:00", "%Y-%m-%d %H:%M:%S"), 1, 2),
]
myschema = StructType(
[
StructField("dateTime", TimestampType()),
StructField("group", IntegerType()),
StructField("val", IntegerType()),
]
)
df = spark.createDataFrame(data=data, schema=myschema)
window = Window.partitionBy("group").orderBy(["dateTime"])
df = df.withColumn(
"timeDiffFromPrevInSeconds",
(
(
(F.col("dateTime")).cast("long")
- F.first("dateTime").over(window).cast("long")
)
).cast("long"),
).withColumn(
"rolling_average",
F.when(
F.col("timeDiffFromPrevInSeconds") >= 86400, F.avg(F.col("val")).over(window)
).otherwise(None),
)
df.show()
输出为:
+-------------------+-----+---+---------------+
| dateTime|group|val|rolling_average|
+-------------------+-----+---+---------------+
|2023-01-01 00:00:00| 1| 1| null|
|2023-01-01 00:05:00| 1| 2| null|
|2023-01-01 06:05:00| 1| 3| null|
|2023-01-01 18:00:00| 1| 4| null|
|2023-01-02 00:00:00| 1| 2| 2.4|
+-------------------+-----+---+---------------+