Using Spark / Scala, is there a way to join complex data structures?

Problem Description

I have a dataframe with a complex structure. Inside that structure, I need to replace one value with another based on a mapping held in a second dataframe. Currently we do this by exploding the dataframe, joining, and then grouping by with an aggregation. The problem is that we are transforming 3.5B records against 210B records, and the group by is very expensive. The data I start with is already grouped the way I need it. Is there some way to accomplish this without the explode and the group by?

Here is some sample code from a Zeppelin notebook that illustrates our current approach:

import spark.implicits._
import org.apache.spark.sql.functions.{explode, collect_list, struct}

case class A(device_id: Long, cluster: Seq[B])
case class B(location_id: Long, score: Double)
case class C(location_id: Long, location_key: String)
case class D(location_key: String, score: Double)

val df1 = Seq(
  A(1L, Seq(B(1L, 1.1), B(2L, 2.2), B(3L, 3.3))),
  A(2L, Seq(B(4L, 4.4), B(5L, 5.5), B(6L, 6.6))),
  A(3L, Seq(B(7L, 7.7), B(8L, 8.8), B(9L, 9.9)))
).toDF

val df2 = Seq(
  C(1L, "a"),
  C(2L, "b"),
  C(3L, "c"),
  C(4L, "d"),
  C(5L, "e"),
  C(6L, "f"),
  C(7L, "g"),
  C(8L, "h"),
  C(9L, "i")
).toDF

val df3 = df1
  .select($"device_id", explode($"cluster").as("record"))
  .select($"device_id", $"record.location_id".as("location_id"), $"record.score".as("score"))

val df4 = df3
  .join(df2, "location_id")
  .select($"device_id", $"location_key", $"score")

val df5 = df4
  .groupBy($"device_id")
  .agg(
    collect_list(struct($"location_key", $"score")).as("cluster")
  )

df1.printSchema()
df1.show(3, false)

df5.printSchema()
df5.show(3, false)

The output looks like this:

root
 |-- device_id: long (nullable = false)
 |-- cluster: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- location_id: long (nullable = false)
 |    |    |-- score: double (nullable = false)

+---------+------------------------------+
|device_id|cluster                       |
+---------+------------------------------+
|1        |[[1, 1.1], [2, 2.2], [3, 3.3]]|
|2        |[[4, 4.4], [5, 5.5], [6, 6.6]]|
|3        |[[7, 7.7], [8, 8.8], [9, 9.9]]|
+---------+------------------------------+

root
 |-- device_id: long (nullable = false)
 |-- cluster: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- location_key: string (nullable = true)
 |    |    |-- score: double (nullable = true)

+---------+------------------------------+
|device_id|cluster                       |
+---------+------------------------------+
|1        |[[a, 1.1], [c, 3.3], [b, 2.2]]|
|2        |[[e, 5.5], [d, 4.4], [f, 6.6]]|
|3        |[[g, 7.7], [h, 8.8], [i, 9.9]]|
+---------+------------------------------+
Tags: scala, apache-spark, join
1 Answer

I'm not sure how much this helps with the performance problem, let alone whether it fully answers your question, but the explode step can be replaced with a flatMap, which might improve performance slightly (though I suspect explode and flatMap are quite similar under the hood, so the difference may well be negligible):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{collect_list, struct}

case class A(device_id: Long, cluster: Seq[B])
case class B(location_id: Long, score: Double)
case class C(location_id: Long, location_key: String)
case class D(location_key: String, score: Double)
case class E(device_id: Long, location_id: Long, score: Double)

object ComplexDataStructures {
  def main(args: Array[String]): Unit = {

    val spark = SparkSession
      .builder()
      .appName("Spark SQL basic example")
      .config("spark.master", "local")
      .getOrCreate()

    import spark.implicits._

    val df1 = Seq(
      A(1L, Seq(B(1L, 1.1), B(2L, 2.2), B(3L, 3.3))),
      A(2L, Seq(B(4L, 4.4), B(5L, 5.5), B(6L, 6.6))),
      A(3L, Seq(B(7L, 7.7), B(8L, 8.8), B(9L, 9.9)))
    ).toDS.as[A]

    val df2 = Seq(
      C(1L, "a"),
      C(2L, "b"),
      C(3L, "c"),
      C(4L, "d"),
      C(5L, "e"),
      C(6L, "f"),
      C(7L, "g"),
      C(8L, "h"),
      C(9L, "i")
    ).toDS.as[C]

    // Flatten each cluster into one row per (device_id, location_id, score)
    // using a typed flatMap instead of explode.
    val df3 = df1.flatMap { case A(deviceId, cluster) =>
      cluster.map(b => E(deviceId, b.location_id, b.score))
    }

    // Map each location_id to its location_key.
    val df4 = df3.join(df2, df3("location_id") === df2("location_id"))
      .select(df3("device_id"), df2("location_key"), df3("score"))

    // Regroup into the original nested structure.
    val df5 = df4.groupBy("device_id")
      .agg(collect_list(struct("location_key", "score")).as("cluster"))

    df5.printSchema()
    df5.show(3, false)

  }
}
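
If the location_id-to-location_key mapping is small enough to collect and broadcast, the explode and group by could be avoided entirely by rewriting the nested array in place. Below is a minimal sketch of that idea, assuming df1, df2 and import spark.implicits._ from the code above are in scope, that the mapping fits comfortably in memory, and that every location_id in df1 appears in df2 (keyLookup and df6 are names introduced here for illustration, not from the question):

// Sketch only: collect the small mapping table and broadcast it to the executors.
val keyByLocationId: Map[Long, String] =
  df2.collect().map(c => c.location_id -> c.location_key).toMap
val keyLookup = spark.sparkContext.broadcast(keyByLocationId)

// Rewrite each cluster element in place, keeping the existing grouping,
// so neither an explode nor a groupBy shuffle is needed.
val df6 = df1.map { a =>
  (a.device_id, a.cluster.map(b => D(keyLookup.value(b.location_id), b.score)))
}.toDF("device_id", "cluster")

df6.printSchema()
df6.show(3, false)

On newer Spark versions something similar can be expressed without the typed API, for example with the transform higher-order function over the array column, but the broadcast-map version above is the most direct fit for the typed code in this answer.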