基本上,最终目标是创建类似
dollarSum
的东西,它将返回与 ROUND(SUM(col), 2)
相同的值。
我正在使用 Databricks 运行时 10.4 LTS ML,它显然对应于 Spark 3.2.1 和 Scala 2.12。
我能够遵循 UDAF 的教程/示例代码,并使用它来创建类似于内置
EVERY
函数的东西。但这似乎更像ImperativeAggregate
,而我想要的可能更像DeclarativeAggregate
,参见。 Spark源码中的注释.
总的来说,我无法在网上找到任何关于如何以简单的方式扩展内置聚合函数的文档,在这种方式中,您只需修改“完成”或“评估”步骤,甚至只需添加额外的内容即可行为。
到目前为止我尝试过的: 到目前为止,我已经尝试了至少四件事,但都不起作用。
尝试1:
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.{sum, round}
object dollarSum extends Aggregator[Double, Double, Double] {
def zero: Double = sum.zero
def reduce(buffer: Double, row: Double): Double = sum.reduce
def merge(buffer1: Double, buffer2: Double) Double = sum.merge
def finish(reduction: Double): Double = {
sum.finish(reduction)
round(reduction, 2)
}
def bufferEncoder: Encoder[Double] = sum.bufferEncoder
def outputEncoder: Encoder[Double] = sum.outputEncoder
}
尝试2:我尝试从这里复制粘贴修改代码。这似乎失败了,因为内置
Sum
类的大多数属性和方法似乎都是私有的(可能是因为开发人员不希望像我这样不知道自己在做什么的人破坏代码)。但我不知道我可以使用什么公共接口/API 来获得我想要的东西。
import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.functions.round
import org.apache.spark.sql.catalyst.expressions.EvalMode
import org.apache.spark.sql.types.DecimalType
trait dollarSum extends Sum {
override lazy val evaluateExpression: Expression = {
Sum.resultType match {
case d: DecimalType =>
val checkOverflowInSum =
CheckOverflowInSum(Sum.sum, d, evalMode != EvalMode.ANSI, getContextOrNull())
If(isEmpty, Literal.create(null, Sum.resultType), checkOverflowInSum)
case _ if shouldTrackIsEmpty =>
If(isEmpty, Literal.create(null, Sum.resultType), Sum.sum)
case _ => round(Sum.sum, 2)
}
}
}
由于其他一些缺失的导入,这可能仍然会失败,但由于尝试访问可能不应该访问的私有方法和属性,我再次无法在调试中走得那么远。
尝试 3:同一文件中
try_sum
的源代码似乎更接近于使用“公共 API”求和,所以我尝试复制粘贴修改它。但ExpressionBuilder
也似乎是一个私人课程,所以这也失败了。
import org.apache.spark.sql.catalyst.analysis.ExpressionBuilder
import org.apache.spark.sql.catalyst.expressions.Expression
object DollarSumExpressionBuilder extends ExpressionBuilder {
override def build(funcName: String, expressions: Seq[Expression]): Expression = {
val numArgs = expressions.length
if (numArgs == 1) {
round(Sum(expressions.head),2)
} else {
throw QueryCompilationErrors.wrongNumArgsError(funcName, Seq(1, 2), numArgs)
}
}
}
然后的想法是,如果可行,我会尝试以与源代码中的 Spark SQL
注册
TRY_SUM
相同的方式注册该函数,参见在这里。但是我收到了关于 ExpressionBuilder
不存在的错误,这似乎表明它也是包的私有类,因此不是我可以用来扩展 SUM
的公共接口。
我也不清楚
SUM
构造函数的返回类型是什么,我认为它可能是 AggregateExpression
继承自 Expression
。而且我不确定 round
的输入类型是什么,看起来可能是 org.apache.spark.sql.Column
,如果是这样,我不知道如何从 Expression
转换为 Column
。
例如是否在上面
round(org.apache.spark.sql.Column((Sum(expressions.head)),2)
或
round(org.apache.spark.sql.functions.col((Sum(expressions.head)),2)
将能够实现所需的类型转换(似乎都不起作用)。
尝试4: 沿着上述思路,不知道需要哪些类型以及如何在它们之间进行转换,以及
SUM
的公共接口是什么,我尝试使用 org.apache.spark.sql.functions.sum
作为 SUM
的“公共接口”,但是这也不起作用。
具体
import org.apache.spark.sql.functions.{round, sum}
import org.apache.spark.sql.Column
// originally I had `expression: org.apache.spark.sql.catalyst.expressions.Expression` but that didn't work
def dollarSum(expression: Column): Column = {round(sum(expression), 2)}
实际上不会抛出任何错误,但是当我尝试将结果对象实际注册为(n聚合)函数时,它失败了,特别是
spark.udf.register("dollar_sum", functions.udaf(dollarSum))
不起作用,也不起作用
spark.udf.register("dollar_sum", functions.udf(dollarSum))
哇,这个问题中有很多有趣的东西,而且非常熟悉:Quality 的 agg_expr 是我进入这个领域的旅程。
要构建自定义表达式,您可能需要将代码放入 org.apache.spark.sql 包中,例如注册函数。使用 SparkSession 实例 FunctionRegistry createOrReplaceTempFunction (例如 SparkSession.getActiveSession.get.sessionState.functionRegistry),您可以在会话中使用该函数。如果您在配置单元视图等中需要它,则必须使用 SparkSessionExtensions 作为作用域和 FunctionRegistry.builtin.registerFunction。
实际注册的ExpressionBuilder只是Seq[Expression] => Expression的别名,代表传入构造表达式的参数。
因此,根据 Spark 版本(内部 api 发生很大变化):
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{Round, Literal, EvalMode}
import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
SparkSession.getActiveSession.get.sessionState.functionRegistry.
createOrReplaceTempFunction("dollarSum", exps => Round(
Sum(exps.head, EvalMode.TRY).toAggregateExpression(), Literal(2)), "built-in")
val seq = Seq(1.245, 242.535, 65656.234425, 2343.666)
import sparkSession.implicits._
seq.toDF("amount")//.selectExpr("round(sum(amount), 2)").show
.selectExpr("dollarSum(amount)").show