使用Pyspark进行线性插值

问题描述 投票:0回答:1

我下面有一个数据框

data = [("A", "2022-01-01", 10.0),
    ("A", "2022-01-02", None),
    ("A", "2022-01-03", None),
    ("A", "2022-01-04", 40.0),
    ("B", "2022-01-01", 30.0),
    ("B", "2022-01-02", None),
    ("B", "2022-01-03", 60.0)]

列= [“组”,“日期”,“值”] df = Spark.createDataFrame(数据, 列)

我想按“组”列进行分组,然后按日期线性插值值。 我如何在 Pyspark 2.4 中做到这一点? 感谢您提前回复。

我尝试通过 (next_value - prev_value)/(next_date - prev_date)*(next_date - curr_date) 来计算它。但我找不到一个函数来获取最接近的可用值。我发现火花功能滞后和领先。但这不是我需要的。

我尝试使用此代码来计算前一个非空值和下一个非空值。但当有 2 个连续的空值时,它不起作用。

from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql.functions import when, col, lit
from pyspark.sql.types import DoubleType


data = [("A", "2022-01-01", 10.0),
        ("A", "2022-01-02", None),
        ("A", "2022-01-03", None),
        ("A", "2022-01-04", 40.0),
        ("B", "2022-01-01", 30.0),
        ("B", "2022-01-02", None),
        ("B", "2022-01-03", 60.0)]
columns = ["group", "date", "value"]
df = spark.createDataFrame(data, columns)

df = df.withColumn("date", df["date"].cast("date"))

window_spec = Window.partitionBy("group").orderBy("date")

df = df.withColumn("prev_value", sf.lag(df["value"]).over(window_spec))
df = df.withColumn("next_value", sf.lead(df["value"]).over(window_spec))

df= (
    df
    .withColumn("prev_date", sf.lag(df["value"].isNotNull()).over(window_spec))
    .withColumn("next_date", sf.lead(df["value"].isNotNull()).over(window_spec))
)

我希望得到下面的结果。

+-----+----------+-----+----------+----------+----------+---------- 
+
|group|      date|value|prev_value|next_value| prev_date| next_date|
+-----+----------+-----+----------+----------+----------+---------- 
+
|    B|2022-01-01| 30.0|      null|      60.0|      null|2022-01-03|
|    B|2022-01-02| null|      30.0|      60.0|2022-01-01|2022-01-03|
|    B|2022-01-03| 60.0|      30.0|      null|2022-01-01|      null|
|    A|2022-01-01| 10.0|      null|      40.0|      null|2022-01-04|
|    A|2022-01-02| null|      10.0|      40.0|2022-01-01|2022-01-04|
|    A|2022-01-03| null|      10.0|      40.0|2022-01-01|2022-01-04|
|    A|2022-01-04| 40.0|      10.0|      null|2022-01-01|      null|
+-----+----------+-----+----------+----------+----------+----------+

或者还有其他方法可以对 Pyspark 2.4 中的值进行线性插值吗?

pyspark apache-spark-sql interpolation
1个回答
0
投票

在 Spark SQL 中实现良好的线性插值算法可能很困难。

如果单个组的行数不太大(可能只有几 10000 行),可以将这些行收集到一个数组中(使用

groupBy
collect_list
作为聚合函数),然后 UDF 可以执行以下操作:使用您最喜欢的库进行数字魔法,例如 numpy.interp

定义 UDF

from pyspark.sql import functions as F
from pyspark.sql import types as T

def interpolate(xy):
    import numpy as np
    a=np.array(xy, dtype=np.float64)
    a=a[a[:, 0].argsort()]
    vals=a[~np.isnan(a[:,1])]
    x=a[:,0]
    xp=vals[:,0]
    fp=vals[:,1]
    res = np.interp(x=x, xp=xp, fp=fp)
    return np.column_stack((x, res)).tolist()

interpolate_udf = F.udf(interpolate, T.ArrayType(T.ArrayType(T.FloatType())))

致电UDF

df.withColumn('date_ts', F.unix_timestamp('date')) \
  .withColumn('xy', F.array('date_ts', 'value')) \
  .groupBy('group').agg(F.collect_list('xy').alias('xy')) \
  .withColumn('result', interpolate_udf('xy') ) \
  .withColumn('xy', F.explode('result'))\
  .withColumn('date', F.from_unixtime(F.col('xy')[0]).cast('date')) \
  .withColumn('value', F.col('xy')[1]) \
  .drop('xy', 'result') \
  .show()

结果:

+-----+----------+-----+
|group|date      |value|
+-----+----------+-----+
|A    |2022-01-01|10.0 |
|A    |2022-01-02|20.0 |
|A    |2022-01-03|30.0 |
|A    |2022-01-04|40.0 |
|B    |2022-01-01|30.0 |
|B    |2022-01-02|45.0 |
|B    |2022-01-03|60.0 |
+-----+----------+-----+
© www.soinside.com 2019 - 2024. All rights reserved.