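I am trying to derive a new column, "final". Each value in the column is derived from the previous value within the same group. In my data, colA, colB, colC and colD together form a group, and the only value that changes within a group is colE.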
from pyspark.sql import SparkSession
spark = SparkSession.builder \
    .appName("example") \
    .getOrCreate()
# Sample data
data = [
("A", "2003-03-01", 1, 11, 1, 10, 0.1),
("A", "2003-03-01", 1, 11, 2, 10, 0.2),
("A", "2003-03-01", 1, 11, 3, 10, 0.3),
("A", "2003-03-01", 1, 11, 4, 10, 0.1),
("A", "2003-03-01", 1, 11, 5, 10, 0.2),
]
# Create DataFrame
df = spark.createDataFrame(data, ["colA", "colB", "colC", "colD", "colE", "value", "pred"])
# Show DataFrame
df.show()
# Output data
output = [
("A", "2003-03-01", 1, 11, 1, 10, 0.1, 1.0),
("A", "2003-03-01", 1, 11, 2, 10, 0.2, 0.2),
("A", "2003-03-01", 1, 11, 3, 10, 0.3, 0.06),
("A", "2003-03-01", 1, 11, 4, 10, 0.1, 0.006),
("A", "2003-03-01", 1, 11, 5, 10, 0.2, 0.0012),
]
# Create DataFrame
output_df = spark.createDataFrame(output, ["colA", "colB", "colC", "colD", "colE", "value", "pred", "final"])
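The "final" column is derived as follows: for the first row within a group, final equals value * pred. For every remaining row in the group, final equals final (from the previous row) * pred.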
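To make the rule concrete, here is a minimal plain-Python sketch of the recurrence for a single group. It is not Spark code, and the function name `running_final` is hypothetical and used only for illustration.

```python
def running_final(value, preds):
    """Apply the rule for one group, with rows already ordered by colE.

    First row:  final = value * pred
    Other rows: final = previous final * pred
    """
    finals = []
    for i, pred in enumerate(preds):
        previous = value if i == 0 else finals[-1]
        finals.append(previous * pred)
    return finals


# Sample group from the question: value is 10 on every row.
finals = running_final(10, [0.1, 0.2, 0.3, 0.1, 0.2])
# Rounded only to hide floating-point noise.
print([round(x, 6) for x in finals])
```

After rounding, this gives the values in the expected "final" column: 1.0, 0.2, 0.06, 0.006, 0.0012.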
My current logic is as follows:
from pyspark.sql.functions import col, lag, when
from pyspark.sql.window import Window

window_spec = Window.partitionBy('colA', 'colB', 'colC', 'colD').orderBy('colE')
# This will derive value for first row within each group
a1 = df.withColumn('final', when(lag('colE').over(window_spec).isNull(), col('pred') * col('value')))
# For the remaining rows, multiply the previous row's final by pred
a1 = a1.withColumn('final', when(col('final').isNotNull(), col('final'))
                   .otherwise(lag(col('final')).over(window_spec) * col('pred')))
However, with the above logic, values are only generated for the first two rows within each group.
# Output data
incorrect_output = [
("A", "2003-03-01", 1, 11, 1, 10, 0.1, 1),
("A", "2003-03-01", 1, 11, 2, 10, 0.2, 0.2),
("A", "2003-03-01", 1, 11, 3, 10, 0.3, None),
("A", "2003-03-01", 1, 11, 4, 10, 0.1, None),
("A", "2003-03-01", 1, 11, 5, 10, 0.2, None),
]
What am I doing wrong? Can you help?
Take a look at this:
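import pyspark.sql.functions as f
from pyspark.sql.types import FloatType
from pyspark.sql.window import Window

# Same partitioning and ordering for both window columns
w = Window.partitionBy("colA", "colB", "colC", "colD").orderBy("colE")

df = (
    df
    # value taken from the first row of the group
    .withColumn("first_value", f.first("value").over(w))
    # with an orderBy, the default frame ends at the current row, so this collects the preds seen so far
    .withColumn("preds", f.collect_list("pred").over(w))
    .select(
        df['*'],
        # Final = first_value * product of the preds collected so far
        (f.col('first_value') * f.expr('aggregate(preds, cast(1 as DOUBLE), (acc, x) -> acc * x)')).cast(FloatType()).alias('Final')
    )
)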
The output is:
+----+----------+----+----+----+-----+----+------+
|colA|      colB|colC|colD|colE|value|pred| Final|
+----+----------+----+----+----+-----+----+------+
|   A|2003-03-01|   1|  11|   1|   10| 0.1|   1.0|
|   A|2003-03-01|   1|  11|   2|   10| 0.2|   0.2|
|   A|2003-03-01|   1|  11|   3|   10| 0.3|  0.06|
|   A|2003-03-01|   1|  11|   4|   10| 0.1| 0.006|
|   A|2003-03-01|   1|  11|   5|   10| 0.2|0.0012|
+----+----------+----+----+----+-----+----+------+
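As for why the original attempt stops after two rows: a window function such as lag reads the column as it exists in the input DataFrame, not the values being written by the same withColumn call. The second pass therefore only sees a non-null final on the first row. The plain-Python sketch below imitates both approaches on one group. It is an illustration rather than Spark code, and the helper `lag` here is a hypothetical stand-in for the window function.

```python
from functools import reduce

value = 10
preds = [0.1, 0.2, 0.3, 0.1, 0.2]


def lag(column):
    """Stand-in for a window lag of 1: shift the column down by one row."""
    return [None] + column[:-1]


# Pass 1 of the question's logic: only the first row of the group gets a value.
pass1 = [value * preds[0]] + [None] * (len(preds) - 1)

# Pass 2: lag sees pass1 as it already is, so only row 2 finds a
# non-null previous value; rows 3 onwards still see None.
pass2 = [
    cur if cur is not None else (None if prev is None else prev * p)
    for cur, prev, p in zip(pass1, lag(pass1), preds)
]

# The approach in this answer: for each row, fold the preds seen so far
# into a product and multiply by the group's first value.
prefix_fold = [
    value * reduce(lambda acc, x: acc * x, preds[: i + 1], 1.0)
    for i in range(len(preds))
]

print(pass2)
print([round(x, 6) for x in prefix_fold])
```

The first print shows a value on rows 1 and 2 and None afterwards, which matches the incorrect output in the question. The second gives the expected values because each row is computed directly from the original pred column, so no row depends on another row's result.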