I have a dataframe with a complex structure. Inside that structure I need to replace one value with another, based on a mapping held in a second dataframe. Currently we do this by exploding the dataframe, joining to the mapping, and grouping back up with an aggregation. The problem is that we are transforming 3.5B records into 210B records, and the groupBy is extremely expensive. The data I start with is already grouped exactly the way I need it. Is there some way to accomplish this without the explode and the groupBy?
Here is some sample code from a Zeppelin notebook illustrating our current approach:
import spark.implicits._
import org.apache.spark.sql.functions.{collect_list, explode, struct}

case class A(device_id: Long, cluster: Seq[B])
case class B(location_id: Long, score: Double)
case class C(location_id: Long, location_key: String)
case class D(location_key: String, score: Double)

val df1 = Seq(
  A(1L, Seq(B(1L, 1.1), B(2L, 2.2), B(3L, 3.3))),
  A(2L, Seq(B(4L, 4.4), B(5L, 5.5), B(6L, 6.6))),
  A(3L, Seq(B(7L, 7.7), B(8L, 8.8), B(9L, 9.9)))
).toDF

val df2 = Seq(
  C(1L, "a"),
  C(2L, "b"),
  C(3L, "c"),
  C(4L, "d"),
  C(5L, "e"),
  C(6L, "f"),
  C(7L, "g"),
  C(8L, "h"),
  C(9L, "i")
).toDF

// Explode the nested array into one row per (device_id, location_id, score).
val df3 = df1
  .select($"device_id", explode($"cluster").as("record"))
  .select($"device_id", $"record.location_id".as("location_id"), $"record.score".as("score"))

// Join against the mapping to swap location_id for location_key.
val df4 = df3
  .join(df2, "location_id")
  .select($"device_id", $"location_key", $"score")

// Group back up and re-assemble the nested array -- this is the expensive step.
val df5 = df4
  .groupBy($"device_id")
  .agg(
    collect_list(struct($"location_key", $"score")).as("cluster")
  )

df1.printSchema()
df1.show(3, false)
df5.printSchema()
df5.show(3, false)
The output looks like this:
root
 |-- device_id: long (nullable = false)
 |-- cluster: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- location_id: long (nullable = false)
 |    |    |-- score: double (nullable = false)
+---------+------------------------------+
|device_id|cluster |
+---------+------------------------------+
|1 |[[1, 1.1], [2, 2.2], [3, 3.3]]|
|2 |[[4, 4.4], [5, 5.5], [6, 6.6]]|
|3 |[[7, 7.7], [8, 8.8], [9, 9.9]]|
+---------+------------------------------+
root
 |-- device_id: long (nullable = false)
 |-- cluster: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- location_key: string (nullable = true)
 |    |    |-- score: double (nullable = true)
+---------+------------------------------+
|device_id|cluster |
+---------+------------------------------+
|1 |[[a, 1.1], [c, 3.3], [b, 2.2]]|
|2 |[[e, 5.5], [d, 4.4], [f, 6.6]]|
|3 |[[g, 7.7], [h, 8.8], [i, 9.9]]|
+---------+------------------------------+
I'm not sure how much this helps with the performance problem, let alone whether it answers your question, but you can change the explode step to use flatMap, which might buy you a little performance (although explode and flatMap should be quite similar under the hood, so the difference may well be nonexistent):
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{collect_list, struct}

case class A(device_id: Long, cluster: Seq[B])
case class B(location_id: Long, score: Double)
case class C(location_id: Long, location_key: String)
case class D(location_key: String, score: Double)
case class E(device_id: Long, location_id: Long, score: Double)

object ComplexDataStructures {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("Spark SQL basic example")
      .config("spark.master", "local")
      .getOrCreate()

    import spark.implicits._

    val df1 = Seq(
      A(1L, Seq(B(1L, 1.1), B(2L, 2.2), B(3L, 3.3))),
      A(2L, Seq(B(4L, 4.4), B(5L, 5.5), B(6L, 6.6))),
      A(3L, Seq(B(7L, 7.7), B(8L, 8.8), B(9L, 9.9)))
    ).toDS

    val df2 = Seq(
      C(1L, "a"),
      C(2L, "b"),
      C(3L, "c"),
      C(4L, "d"),
      C(5L, "e"),
      C(6L, "f"),
      C(7L, "g"),
      C(8L, "h"),
      C(9L, "i")
    ).toDS

    // flatMap replaces the explode: each A fans out into one E per element of its cluster.
    val df3 = df1.flatMap { case A(id, cluster) => cluster.map(b => E(id, b.location_id, b.score)) }

    val df4 = df3.join(df2, df3("location_id") === df2("location_id"))
      .select(df3("device_id"), df2("location_key"), df3("score"))

    val df5 = df4.groupBy("device_id")
      .agg(collect_list(struct("location_key", "score")).as("cluster"))

    df5.printSchema()
    df5.show(3, false)
  }
}
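For the question as actually asked -- avoiding the explode and the groupBy entirely -- one option, assuming the location_id-to-location_key mapping is small enough to collect and broadcast (it should be far smaller than the 3.5B-row dataframe), is to rewrite each nested array in place with a plain map over the typed Dataset. This is only a sketch under that assumption, reusing spark, df1, df2, and the case classes from the notebook above; the case class F is hypothetical, introduced just to name the output row:

// Hypothetical output row type, analogous to A but holding D elements.
case class F(device_id: Long, cluster: Seq[D])

// Collect the (small) mapping to the driver and broadcast it to the
// executors as a plain Scala Map[Long, String].
val locationKeyById: Map[Long, String] =
  df2.as[C].collect().map(c => c.location_id -> c.location_key).toMap
val mapping = spark.sparkContext.broadcast(locationKeyById)

// Rewrite each nested element in place. flatMap over the Option drops
// elements whose location_id has no mapping, mirroring the inner join.
val df6 = df1.as[A].map { a =>
  F(a.device_id, a.cluster.flatMap(b => mapping.value.get(b.location_id).map(k => D(k, b.score))))
}

df6.printSchema()
df6.show(3, false)

Because every row is rewritten locally, this plan involves no shuffle at all; the only data movement is collecting df2 to the driver. If you would rather stay in the untyped DataFrame API and are on Spark 3.0 or later, the same idea can be expressed with the transform higher-order function and a map literal (again a sketch, with the same small-mapping assumption):

import org.apache.spark.sql.functions.{element_at, struct, transform, typedLit}

// Embed the small mapping in the query plan as a MapType literal.
val keyById = typedLit(locationKeyById)

// transform rewrites every array element in place: no explode, no groupBy.
val df7 = df1.withColumn(
  "cluster",
  transform($"cluster", b =>
    struct(
      element_at(keyById, b.getField("location_id")).as("location_key"),
      b.getField("score").as("score")
    )
  )
)

df7.printSchema()
df7.show(3, false)

One semantic difference to be aware of: the inner join silently drops location_ids with no mapping, and the first sketch mirrors that, whereas element_at yields a null location_key for unmatched ids instead of dropping them.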