我正在尝试在 Spark 中实现自定义催化剂表达式,它将数据帧中的每一列解析为字符串数组。附上玩具示例。
case class CustomExpression(child: Expression)
extends UnaryExpression {
override def dataType: DataType = ArrayType(StringType)
override protected def withNewChildInternal(newChild: Expression): Expression =
copy(child = newChild)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
// Generate the code to produce the desired output
val digits = ctx.freshName("digits")
val value = ctx.freshName("value")
val code =
code"""
String[] $digits = new String[]{"1", "2", "3"};
${ev.value} = $digits;
"""
ev.copy(code = code, isNull = FalseLiteral)
}
下面是如何使用它
val df = currentDf.withColumn(colName, new Column(new CustomExpression(col(colName).expr)))
当我想用
df.show()
显示 df 时,编译器给出错误 File ' generated.java', Line 32, Column 1: Expression "value_1" is not an rvalue。当我使用 df.explain("codegen") 检查生成的代码时,它给出错误 File ' generated.java', Line 56, Column 1: Expression "project_value_0" is not an rvalue。生成的代码片段是
/* 055 */ String[] project_digits_0 = new String[]{"1", "2", "3"};
/* 056 */ **project_value_0** = project_digits_0;
/* 057 */ String[] project_digits_1 = new String[]{"1", "2", "3"};
/* 058 */ project_value_2 = project_digits_1;
/* 059 */ columnartorow_mutableStateArray_3[1].reset();
/* 060 */
/* 061 */ if (false) {
/* 062 */ columnartorow_mutableStateArray_3[1].setNullAt(0);
/* 063 */ } else {
/* 064 */ // Remember the current cursor so that we can calculate how many bytes are
/* 065 */ // written later.
/* 066 */ final int project_previousCursor_0 = columnartorow_mutableStateArray_3[1].cursor();
/* 067 */
/* 068 */ final ArrayData project_tmpInput_0 = **project_value_0;**
/* 069 */ if (project_tmpInput_0 instanceof UnsafeArrayData) {
/* 070 */ columnartorow_mutableStateArray_3[1].write((UnsafeArrayData) project_tmpInput_0);
/* 071 */ } else {
/* 072 */ final int project_numElements_0 = project_tmpInput_0.numElements();
根据 doGenCode() 实现,project_value_0 是表达式代码值 ${ev.value} 的变量。如何修复错误?
在上面的玩具示例中,我期望输出 df 具有字符串数组的一列。
Quality中有很多自定义表达式,当然还有主要的spark代码CreateArray,你可以参考,但你可能只需要使用正确的类型,没有右值:
val digits = ctx.freshName("digits")
val code =
code"""
String[] $digits = new String[]{"1", "2", "3"};
ArrayData ${ev.value} = new GenericArrayData($digits);
"""
也就是说,Catalyst 中的数组不是基本类型,它们是 ArrayData。 您可以在这里看到创建:createArrayData。
为了轻松学习如何在 Catalyst 中使用数组,我建议您首先使用 CodegenFallback 并让 eval 工作。 在这种情况下,使用 new GenericArrayData(array/scala.collection.Seq) 和 $digits 可能会起作用。
使用 ArrayData 时还有其他奇怪的地方(在本示例中,您可能不会遇到它们,因为您将创建它们),此代码可以解决这些问题。
同样返回结构体必须是 InteralRow / GenericInternalRow 和 String 的 UTF8String。