我下面有一个数据框
data = [("A", "2022-01-01", 10.0),
("A", "2022-01-02", None),
("A", "2022-01-03", None),
("A", "2022-01-04", 40.0),
("B", "2022-01-01", 30.0),
("B", "2022-01-02", None),
("B", "2022-01-03", 60.0)]
列= [“组”,“日期”,“值”] df = Spark.createDataFrame(数据, 列)
我想按“组”列进行分组,然后按日期线性插值值。 我如何在 Pyspark 2.4 中做到这一点? 感谢您提前回复。
我尝试通过 (next_value - prev_value)/(next_date - prev_date)*(next_date - curr_date) 来计算它。但我找不到一个函数来获取最接近的可用值。我发现火花功能滞后和领先。但这不是我需要的。
我尝试使用此代码来计算前一个非空值和下一个非空值。但当有 2 个连续的空值时,它不起作用。
from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql.functions import when, col, lit
from pyspark.sql.types import DoubleType
data = [("A", "2022-01-01", 10.0),
("A", "2022-01-02", None),
("A", "2022-01-03", None),
("A", "2022-01-04", 40.0),
("B", "2022-01-01", 30.0),
("B", "2022-01-02", None),
("B", "2022-01-03", 60.0)]
columns = ["group", "date", "value"]
df = spark.createDataFrame(data, columns)
df = df.withColumn("date", df["date"].cast("date"))
window_spec = Window.partitionBy("group").orderBy("date")
df = df.withColumn("prev_value", sf.lag(df["value"]).over(window_spec))
df = df.withColumn("next_value", sf.lead(df["value"]).over(window_spec))
df= (
df
.withColumn("prev_date", sf.lag(df["value"].isNotNull()).over(window_spec))
.withColumn("next_date", sf.lead(df["value"].isNotNull()).over(window_spec))
)
我希望得到下面的结果。
+-----+----------+-----+----------+----------+----------+----------
+
|group| date|value|prev_value|next_value| prev_date| next_date|
+-----+----------+-----+----------+----------+----------+----------
+
| B|2022-01-01| 30.0| null| 60.0| null|2022-01-03|
| B|2022-01-02| null| 30.0| 60.0|2022-01-01|2022-01-03|
| B|2022-01-03| 60.0| 30.0| null|2022-01-01| null|
| A|2022-01-01| 10.0| null| 40.0| null|2022-01-04|
| A|2022-01-02| null| 10.0| 40.0|2022-01-01|2022-01-04|
| A|2022-01-03| null| 10.0| 40.0|2022-01-01|2022-01-04|
| A|2022-01-04| 40.0| 10.0| null|2022-01-01| null|
+-----+----------+-----+----------+----------+----------+----------+
或者还有其他方法可以对 Pyspark 2.4 中的值进行线性插值吗?
在 Spark SQL 中实现良好的线性插值算法可能很困难。
如果单个组的行数不太大(可能只有几 10000 行),可以将这些行收集到一个数组中(使用
groupBy
和 collect_list
作为聚合函数),然后 UDF 可以执行以下操作:使用您最喜欢的库进行数字魔法,例如 numpy.interp。
定义 UDF
from pyspark.sql import functions as F
from pyspark.sql import types as T
def interpolate(xy):
import numpy as np
a=np.array(xy, dtype=np.float64)
a=a[a[:, 0].argsort()]
vals=a[~np.isnan(a[:,1])]
x=a[:,0]
xp=vals[:,0]
fp=vals[:,1]
res = np.interp(x=x, xp=xp, fp=fp)
return np.column_stack((x, res)).tolist()
interpolate_udf = F.udf(interpolate, T.ArrayType(T.ArrayType(T.FloatType())))
致电UDF
df.withColumn('date_ts', F.unix_timestamp('date')) \
.withColumn('xy', F.array('date_ts', 'value')) \
.groupBy('group').agg(F.collect_list('xy').alias('xy')) \
.withColumn('result', interpolate_udf('xy') ) \
.withColumn('xy', F.explode('result'))\
.withColumn('date', F.from_unixtime(F.col('xy')[0]).cast('date')) \
.withColumn('value', F.col('xy')[1]) \
.drop('xy', 'result') \
.show()
结果:
+-----+----------+-----+
|group|date |value|
+-----+----------+-----+
|A |2022-01-01|10.0 |
|A |2022-01-02|20.0 |
|A |2022-01-03|30.0 |
|A |2022-01-04|40.0 |
|B |2022-01-01|30.0 |
|B |2022-01-02|45.0 |
|B |2022-01-03|60.0 |
+-----+----------+-----+