I'm trying to dynamically add columns to a DataFrame from a Seq of strings.
Here is an example. The source DataFrame looks like this:
+---+----+----+----+---+
| id|   A|   B|   C|  D|
+---+----+----+----+---+
|  1|toto|tata|titi|   |
|  2| bla| blo|    |   |
|  3|   b|   c|   a|  d|
+---+----+----+----+---+
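I also have a Seq of strings containing the names of the columns I want to add. If a column already exists in the source DataFrame, it must be skipped (a kind of set difference), as shown below.
The Seq looks like: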
val columns = Seq("A", "B", "F", "G", "H")
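The "skip columns that already exist" rule can be checked in plain Scala before touching Spark. This is a minimal sketch, assuming the existing names come as an Array[String] (which is what sourceDF.columns returns); MissingColumns and missing are hypothetical names used only for illustration.

```scala
object MissingColumns {
  // Returns the requested names that are not already present, keeping the requested order.
  def missing(requested: Seq[String], existing: Array[String]): Seq[String] = {
    val present = existing.toSet // set lookup instead of scanning the array for each name
    requested.filterNot(present.contains)
  }
}
```

With the example data, MissingColumns.missing(Seq("A", "B", "F", "G", "H"), Array("id", "A", "B", "C", "D")) gives Seq("F", "G", "H"). Unlike diff, which removes only one occurrence per matching element, filterNot drops every requested name that is already present.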
Expected result:
+---+----+----+----+----+----+----+----+
| id|   A|   B|   C|   D|   F|   G|   H|
+---+----+----+----+----+----+----+----+
|  1|toto|tata|titi|tutu|null|null|null|
|  2| bla| blo|    |    |null|null|null|
|  3|   b|   c|   a|   d|null|null|null|
+---+----+----+----+----+----+----+----+
What I've done so far is this:
import org.apache.spark.sql.functions.lit

val difference = columns diff sourceDF.columns
val finalDF = difference.foldLeft(sourceDF)((df, field) => if (!sourceDF.columns.contains(field)) df.withColumn(field, lit(null)) else df)
  .select(columns.head, columns.tail:_*) // keeps only the names in columns, so id, C and D are dropped
But I can't figure out how to do this more simply and efficiently in Spark...
Thanks in advance.
Here is a streamlined version of your logic.
scala> df.show
+---+----+----+----+----+
| id| A| B| C| D|
+---+----+----+----+----+
| 1|toto|tata|titi|null|
| 2| bla| blo|null|null|
| 3| b| c| a| d|
+---+----+----+----+----+
scala> import org.apache.spark.sql.functions.lit
scala> val Columns = Seq("A", "B", "F", "G", "H")
scala> val newCol = Columns filterNot df.columns.toSeq.contains
scala> val df1 = newCol.foldLeft(df)((df,name) => df.withColumn(name, lit(null)))
scala> df1.show()
+---+----+----+----+----+----+----+----+
| id| A| B| C| D| F| G| H|
+---+----+----+----+----+----+----+----+
| 1|toto|tata|titi|null|null|null|null|
| 2| bla| blo|null|null|null|null|null|
| 3| b| c| a| d|null|null|null|
+---+----+----+----+----+----+----+----+
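A different technique, sketched here as an alternative rather than as the answer's own method, is to build the whole projection once and pass it to a single select. The Spark documentation for withColumn notes that each call adds a projection, so calling it in a loop can produce large plans, and recommends select with all columns at once. This sketch assumes plain column names (no dots, which col would treat as nested fields) and that StringType is acceptable for the new null columns; AddMissingColumns and addMissing are hypothetical names.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.StringType

object AddMissingColumns {
  // Adds every name in `wanted` that `df` lacks as a null column, using one projection.
  def addMissing(df: DataFrame, wanted: Seq[String]): DataFrame = {
    val present = df.columns.toSet
    val toAdd = wanted.filterNot(present.contains)
    // Existing columns first, then typed nulls; the cast avoids a NullType (void) column.
    val projection = df.columns.toSeq.map(c => col(c)) ++
      toAdd.map(c => lit(null).cast(StringType).as(c))
    df.select(projection: _*)
  }
}
```

The cast matters if the result is written out, since some file formats reject void columns; drop it if an untyped null column is fine for your use.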
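If you don't want to use foldLeft, you can build the withColumn chain as a string and compile it at runtime with a ToolBox obtained from runtimeMirror. This may be faster, although every call pays for a run of the Scala compiler. Check the code below.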
scala> import scala.reflect.runtime.universe.runtimeMirror
scala> import scala.tools.reflect.ToolBox
scala> import org.apache.spark.sql.DataFrame
scala> df.show
+---+----+----+----+----+
| id| A| B| C| D|
+---+----+----+----+----+
| 1|toto|tata|titi|null|
| 2| bla| blo|null|null|
| 3| b| c| a| d|
+---+----+----+----+----+
scala> def compile[A](code: String): DataFrame => A = {
| val tb = runtimeMirror(getClass.getClassLoader).mkToolBox()
| val tree = tb.parse(
| s"""
| |import org.apache.spark.sql.DataFrame
| |def wrapper(context:DataFrame): Any = {
| | $code
| |}
| |wrapper _
| """.stripMargin)
|
| val fun = tb.compile(tree)
| val wrapper = fun()
| wrapper.asInstanceOf[DataFrame => A]
| }
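The compile-at-runtime mechanism behind compile can be tried in isolation. This sketch assumes Scala 2 with scala-compiler on the classpath (as in spark-shell); RuntimeCompile and compileIntFn are hypothetical names, and it uses ToolBox.eval, which compiles and runs the tree in one call, instead of the separate tb.compile step.

```scala
import scala.reflect.runtime.universe.runtimeMirror
import scala.tools.reflect.ToolBox

object RuntimeCompile {
  // Toolbox bound to this class loader; requires scala-compiler at runtime.
  private val tb = runtimeMirror(getClass.getClassLoader).mkToolBox()

  // Parses `body` as the body of a lambda over x: Int, compiles it and returns the function.
  def compileIntFn(body: String): Int => Int = {
    val tree = tb.parse(s"(x: Int) => { $body }")
    tb.eval(tree).asInstanceOf[Int => Int]
  }
}
```

For example, RuntimeCompile.compileIntFn("x * 2 + 1")(20) evaluates to 41. Because each call invokes the compiler, caching the returned function is worthwhile when the same code string is reused.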
scala> def AddColumns(df:DataFrame,withColumnsString:String):DataFrame = {
| val code =
| (s"""
| |import org.apache.spark.sql.functions._
| |import org.apache.spark.sql.DataFrame
| |var data = context.asInstanceOf[DataFrame]
| |data = data
| """ + withColumnsString +
| """
| |
| |data
| """).stripMargin
|
| val fun = compile[DataFrame](code)
| val res = fun(df)
| res
| }
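When a generated code string is assembled from several literals, .stripMargin must apply to the whole concatenation, because a method call binds more tightly than +. This plain-Scala sketch (StripMarginPrecedence, middle, wrong and right are hypothetical names) shows the difference:

```scala
object StripMarginPrecedence {
  val middle = "MID"

  // .stripMargin applies only to the last literal, so "|a" keeps its margin character.
  val wrong: String = """
    |a""" + middle + """
    |b""".stripMargin

  // Parenthesizing applies .stripMargin to the full concatenated string.
  val right: String = ("""
    |a""" + middle + """
    |b""").stripMargin
}
```

Here right is "\naMID\nb", while wrong still contains "|a", which the Scala parser would reject if the string were compiled as code.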
scala> val Columns = Seq("A", "B", "F", "G", "H")
scala> val newCol = Columns filterNot df.columns.toSeq.contains
scala> var cols = ""
scala> newCol.foreach{ name =>
| cols = ".withColumn(\""+ name + "\" , lit(null))" + cols
| }
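The same withColumn chain can be built without a var. This sketch (WithColumnChain and chain are hypothetical names) assumes the column names contain no double quotes or other characters that would break the generated Scala source; since the string is fed to the compiler, it should never be built from untrusted input.

```scala
object WithColumnChain {
  // Builds .withColumn("F", lit(null)).withColumn("G", lit(null))... in the given order.
  def chain(names: Seq[String]): String =
    names.map(n => s""".withColumn("$n", lit(null))""").mkString
}
```

The foreach above prepends each call, which is why the output lists H, G, F; chain(newCol) keeps the requested order F, G, H, and chain(newCol.reverse) reproduces the order shown.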
scala> val df1 = AddColumns(df,cols)
scala> df1.show
+---+----+----+----+----+----+----+----+
| id| A| B| C| D| H| G| F|
+---+----+----+----+----+----+----+----+
| 1|toto|tata|titi|null|null|null|null|
| 2| bla| blo|null|null|null|null|null|
| 3| b| c| a| d|null|null|null|
+---+----+----+----+----+----+----+----+