Deriving a value from the previous row

Problem description

I am trying to derive a new column, "final". Its value is derived by referencing previous values within a group. In my data, colA, colB, colC, and colD together form a group; within a group, the only value that changes is colE.

Create a SparkSession:

from pyspark.sql import SparkSession

spark = SparkSession.builder \
    .appName("example") \
    .getOrCreate()

# Sample data
data = [
    ("A", "2003-03-01", 1, 11, 1, 10, 0.1),
    ("A", "2003-03-01", 1, 11, 2, 10, 0.2),
    ("A", "2003-03-01", 1, 11, 3, 10, 0.3),
    ("A", "2003-03-01", 1, 11, 4, 10, 0.1),
    ("A", "2003-03-01", 1, 11, 5, 10, 0.2),
]

# Create DataFrame
df = spark.createDataFrame(data, ["colA", "colB", "colC", "colD", "colE", "value", "pred"])

# Show DataFrame
df.show()

# Expected output data (note: "final" must be a float in every row,
# otherwise createDataFrame cannot merge the column types)
output = [
    ("A", "2003-03-01", 1, 11, 1, 10, 0.1, 1.0),
    ("A", "2003-03-01", 1, 11, 2, 10, 0.2, 0.2),
    ("A", "2003-03-01", 1, 11, 3, 10, 0.3, 0.06),
    ("A", "2003-03-01", 1, 11, 4, 10, 0.1, 0.006),
    ("A", "2003-03-01", 1, 11, 5, 10, 0.2, 0.0012),
]

# Create DataFrame
output_df = spark.createDataFrame(output, ["colA", "colB", "colC", "colD", "colE", "value", "pred", "final"])

The "final" column is derived as follows: for the first row within a group, final equals value * pred. For each remaining row in the group, final equals final (from the previous row) * pred.
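As a plain-Python illustration of this recurrence (the helper name derive_final is hypothetical, not part of the question):

```python
def derive_final(value, preds):
    """Apply the recurrence: first final = value * pred,
    each later final = previous final * pred."""
    finals = []
    acc = value
    for pred in preds:
        acc = acc * pred  # first iteration gives value * pred
        finals.append(acc)
    return finals

# For value = 10 and preds = [0.1, 0.2, 0.3, 0.1, 0.2] this yields
# approximately [1.0, 0.2, 0.06, 0.006, 0.0012], matching the expected output.
print([round(x, 6) for x in derive_final(10, [0.1, 0.2, 0.3, 0.1, 0.2])])
```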

My current logic is the following:

from pyspark.sql.functions import col, lag, when
from pyspark.sql.window import Window

window_spec = Window.partitionBy('colA', 'colB', 'colC', 'colD').orderBy('colE')
# This derives the value for the first row within each group
a1 = df.withColumn('final', when(lag('colE').over(window_spec).isNull(), col('pred') * col('value')))
# This attempts to fill the remaining rows from the previous row's 'final'
a1 = a1.withColumn('final', when(col('final').isNotNull(), col('final'))
                   .otherwise(lag(col('final')).over(window_spec) * col('pred')))

However, with the above logic, values are produced only for the first two rows within each group.

# Output actually produced (incorrect)
incorrect_output = [
    ("A", "2003-03-01", 1, 11, 1, 10, 0.1, 1.0),
    ("A", "2003-03-01", 1, 11, 2, 10, 0.2, 0.2),
    ("A", "2003-03-01", 1, 11, 3, 10, 0.3, None),
    ("A", "2003-03-01", 1, 11, 4, 10, 0.1, None),
    ("A", "2003-03-01", 1, 11, 5, 10, 0.2, None),
]

What am I doing wrong? Can you help?

Tags: python, apache-spark, pyspark, apache-spark-sql
1 Answer

Take a look at this:

import pyspark.sql.functions as f
from pyspark.sql.types import FloatType
from pyspark.sql.window import Window

w = Window.partitionBy("colA", "colB", "colC", "colD").orderBy("colE")

df = (
    df
    # First 'value' in each group
    .withColumn("first_value", f.first("value").over(w))
    # Cumulative list of preds up to and including the current row
    .withColumn("preds", f.collect_list("pred").over(w))
    .select(
        df['*'],
        # Multiply first_value by the running product of preds
        (f.col('first_value') * f.expr('aggregate(preds, cast(1 as DOUBLE), (acc, x) -> acc * x)')).cast(FloatType()).alias('Final')
    )
)

The output is:

+----+----------+----+----+----+-----+----+------+                              
|colA|      colB|colC|colD|colE|value|pred| Final|
+----+----------+----+----+----+-----+----+------+
|   A|2003-03-01|   1|  11|   1|   10| 0.1|   1.0|
|   A|2003-03-01|   1|  11|   2|   10| 0.2|   0.2|
|   A|2003-03-01|   1|  11|   3|   10| 0.3|  0.06|
|   A|2003-03-01|   1|  11|   4|   10| 0.1| 0.006|
|   A|2003-03-01|   1|  11|   5|   10| 0.2|0.0012|
+----+----------+----+----+----+-----+----+------+
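The `aggregate(preds, cast(1 as DOUBLE), (acc, x) -> acc * x)` expression folds each row's collected list of preds into a running product; its plain-Python counterpart is a fold with `functools.reduce`:

```python
from functools import reduce

# Running product of the preds collected so far, as aggregate() computes
# it for the last row of the example group.
preds = [0.1, 0.2, 0.3, 0.1, 0.2]
running_product = reduce(lambda acc, x: acc * x, preds, 1.0)

# first_value * running product of preds, as in the answer's select()
print(round(10 * running_product, 6))
```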