How do I multiply values within DataFrame groups during aggregation in Spark? [duplicate]

Question (votes: -1, answers: 2)

I have a DataFrame as follows:

val df = Seq(("x", "y", 1),("x", "z", 2),("x", "a", 4), ("x", "a", 5), ("t", "y", 1), ("t", "y2", 6), ("t", "y3", 3), ("t", "y4", 5)).toDF("F1", "F2", "F3")


+---+---+---+
| F1| F2| F3|
+---+---+---+
|  x|  y|  1|
|  x|  z|  2|
|  x|  a|  4|
|  x|  a|  5|
|  t|  y|  1|
|  t| y2|  6|
|  t| y3|  3|
|  t| y4|  5|
+---+---+---+

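How can I group by column "F1" and multiply the values of "F3" within each group?

I can compute a per-group sum as shown below, but I am not sure which function to use for multiplication.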

df.groupBy("F1").agg(sum("F3")).show

+---+-------+
| F1|sum(F3)|
+---+-------+
|  x|     12|
|  t|     15|
+---+-------+
scala apache-spark
2 Answers

1 vote

Define a custom aggregate function as follows:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType}

// Note: UserDefinedAggregateFunction is deprecated as of Spark 3.0.
class Product extends UserDefinedAggregateFunction {
  // Input fields of the aggregate function.
  override def inputSchema: StructType =
    StructType(StructField("value", LongType) :: Nil)

  // Internal fields kept while computing the aggregate.
  override def bufferSchema: StructType =
    StructType(StructField("product", LongType) :: Nil)

  // Output type of the aggregate function.
  override def dataType: DataType = LongType

  // The same input always yields the same output.
  override def deterministic: Boolean = true

  // Initial buffer value: 1 is the identity element for multiplication.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 1L
  }

  // Fold one input row into the buffer; null inputs are skipped.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) * input.getLong(0)
    }
  }

  // Combine two partial products computed on different partitions.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) * buffer2.getLong(0)
  }

  // Produce the final value from the buffer.
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0)
  }
}

Then use it in the aggregation as follows:

import org.apache.spark.sql.functions.col

val product = new Product

val df = Seq(("x", "y", 1),("x", "z", 2),("x", "a", 4), ("x", "a", 5), ("t", "y", 1), ("t", "y2", 6), ("t", "y3", 3), ("t", "y4", 5)).toDF("F1", "F2", "F3")

df.groupBy("F1").agg(product(col("F3"))).show

This is the output:

+---+-----------+
| F1|product(F3)|
+---+-----------+
|  x|         40|
|  t|         90|
+---+-----------+
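
On Spark 3.0 and later, the same per-group product can be sketched with the typed `Aggregator` API registered through `functions.udaf`, which replaces the deprecated `UserDefinedAggregateFunction`. The names `LongProduct`, `ProductExample`, and `productByF1` below are hypothetical and chosen only for illustration; the sketch assumes a local `SparkSession` and that "F3" contains no nulls.

```scala
import org.apache.spark.sql.{DataFrame, Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.{col, udaf}

// Typed aggregator (hypothetical name): input, buffer, and output are all Long.
object LongProduct extends Aggregator[Long, Long, Long] {
  // 1 is the identity element for multiplication.
  def zero: Long = 1L
  // Fold one input value into the running product.
  def reduce(acc: Long, value: Long): Long = acc * value
  // Combine partial products from different partitions.
  def merge(a: Long, b: Long): Long = a * b
  // The buffer already holds the final result.
  def finish(acc: Long): Long = acc
  def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

object ProductExample {
  // Hypothetical helper: groups by "F1" and multiplies "F3" within each group.
  // Assumes "F3" is numeric and has no nulls.
  def productByF1(df: DataFrame): DataFrame = {
    val longProduct = udaf(LongProduct)
    // Cast explicitly so the Int column matches the aggregator's Long input.
    df.groupBy("F1").agg(longProduct(col("F3").cast("long")).as("product"))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("product-udaf")
      .getOrCreate()
    import spark.implicits._

    val df = Seq(("x", "y", 1), ("x", "z", 2), ("x", "a", 4), ("x", "a", 5),
                 ("t", "y", 1), ("t", "y2", 6), ("t", "y3", 3), ("t", "y4", 5))
      .toDF("F1", "F2", "F3")

    productByF1(df).show()
    spark.stop()
  }
}
```

For the sample data this should give the same products as above (40 for "x", 90 for "t"). Spark 3.2 and later also appear to ship a built-in `product` function in `org.apache.spark.sql.functions`; as far as I recall it returns a double rather than a long, so check the documentation for your version before relying on it.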

3 votes

import org.apache.spark.sql.Row

val df = Seq(("x", "y", 1),("x", "z", 2),("x", "a", 4), ("x", "a", 5), ("t", "y", 1), ("t", "y2", 6), ("t", "y3", 3), ("t", "y4", 5)).toDF("F1", "F2", "F3")

// Group rows by F1, then reduce each group pairwise by multiplying F3.
val x = df.select($"F1", $"F3").groupByKey { case r => r.getString(0) }.reduceGroups { (r, r2) =>
  Row(r.getString(0), r.getInt(1) * r2.getInt(1))
}

x.show()

+-----+------------------------------------------+
|value|ReduceAggregator(org.apache.spark.sql.Row)|
+-----+------------------------------------------+
|    x|                                   [x, 40]|
|    t|                                   [t, 90]|
+-----+------------------------------------------+
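
The `reduceGroups` result above carries an awkward column name and a nested row. A possible variation, sketched here under the assumption of a local `SparkSession`, converts the rows to typed tuples first so the output has two flat, named columns. The names `ReduceGroupsProduct` and `productByKey`, and the column name "product", are hypothetical.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

object ReduceGroupsProduct {
  // Hypothetical helper: expects a string column "F1" and an integer column "F3".
  def productByKey(spark: SparkSession, df: DataFrame): DataFrame = {
    import spark.implicits._
    df.select($"F1", $"F3")
      .as[(String, Int)]                  // typed view: (key, value)
      .groupByKey(_._1)                   // group on the F1 key
      .mapValues(_._2.toLong)             // keep only the value, widened to Long
      .reduceGroups((a, b) => a * b)      // pairwise multiplication per group
      .toDF("F1", "product")              // flat, readable column names
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("reduce-groups-product")
      .getOrCreate()
    import spark.implicits._

    val df = Seq(("x", "y", 1), ("x", "z", 2), ("x", "a", 4), ("x", "a", 5),
                 ("t", "y", 1), ("t", "y2", 6), ("t", "y3", 3), ("t", "y4", 5))
      .toDF("F1", "F2", "F3")

    productByKey(spark, df).show()
    spark.stop()
  }
}
```

Widening to `Long` before multiplying reduces the risk of integer overflow for larger groups, though very long products can still overflow a `Long`.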