Pattern matching on a typed Dataset

Problem description (votes: 0, answers: 2)

I'm trying to apply different logic depending on the type of a Spark Dataset. Depending on the type of the case class passed to doWork (Customer or Worker), I have to apply a different kind of aggregation. How can I do that?

import org.apache.spark.sql.{Dataset, SparkSession}

object SparkSql extends App {
  val spark = SparkSession
    .builder()
    .appName("Simple app")
    .config("spark.master", "local")
    .getOrCreate()

  import spark.implicits._ // must come after the session is created

  sealed trait Person {
    def name: String
  }

  final case class Customer(override val name: String, email: String)                extends Person
  final case class Worker(override val name: String, id: Int, skills: Array[String]) extends Person

  val workers: Dataset[Worker] = Seq(
    Worker("Bob", id = 1, skills = Array("communication", "teamwork")),
    Worker("Sam", id = 1, skills = Array("self-motivation"))
  ).toDS

  def doWork(persons: Dataset[Person]): Unit = {
    persons match {
      case ... // Dataset[Customer] ... do something
      case ... // Dataset[Worker] ... do something else
    }
  }

}
scala apache-spark dataset pattern-matching
2 Answers
0 votes

With case classes you can pattern match. Case classes are Scala's way of allowing pattern matching on objects without requiring a lot of boilerplate. Typically, all you need to do is add the case keyword to each class you want to be pattern-matchable.

For example:

abstract class Expr
case class Var(name: String) extends Expr
case class Number(num: Double) extends Expr
case class UnOp(operator: String, arg: Expr) extends Expr
case class BinOp(operator: String, left: Expr, right: Expr) extends Expr

def simplifyTop(expr: Expr): Expr = expr match {
  case UnOp("-", UnOp("-", e))  => e // Double negation
  case BinOp("+", e, Number(0)) => e // Adding zero
  case BinOp("*", e, Number(1)) => e // Multiplying by one
  case _                        => expr
}
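To see how those rules fire, here are a few illustrative calls (added here for clarity; they are not part of the original answer):

simplifyTop(UnOp("-", UnOp("-", Var("x"))))  // Var("x"): double negation removed
simplifyTop(BinOp("+", Var("x"), Number(0))) // Var("x"): adding zero removed
simplifyTop(BinOp("*", Var("y"), Number(2))) // no rule matches, expression returned unchanged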

Applied to your example, I would try:

  def doWork(person: Person): Unit = {
    person match {
      case c: Customer => ??? // do something
      case w: Worker   => ??? // do something else
    }
  }

  workers.foreach(p => doWork(p)) // foreach rather than map: doWork returns Unit, and there is no Encoder for Unit
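Case classes also give you extractors, so you can destructure the fields directly instead of using a type test. A minimal sketch (describe and its print-style bodies are illustrative placeholders, since the original leaves the actual logic open):

def describe(person: Person): String = person match {
  case Customer(name, email)    => s"customer $name <$email>"
  case Worker(name, id, skills) => s"worker $name (#$id): ${skills.mkString(", ")}"
}

workers.map(p => describe(p)).show() // mapping to String works because spark.implicits._ provides an Encoder[String]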

0 votes

Modify your method to accept [T <: parent] and extract the element class name from the Dataset via Dataset.toJavaRDD, as shown below:

import org.apache.spark.sql.{Dataset, SparkSession}

object InheritDataframe {

  private def matcherDef[T <: parent](dfb: Dataset[T]): Unit = {
    // The ClassTag's toString is the element class name (fully qualified if the class lives in a package)
    dfb.toJavaRDD.classTag.toString() match {
      case "child1" => println("child1")
      case "child2" => println("child2")
      case _        => println("Unknown")
    }
  }

  def main(args: Array[String]): Unit = {

    // Constant.getSparkSess is not defined in the answer; a standard local session builder stands in
    val spark = SparkSession
      .builder()
      .appName("Simple app")
      .config("spark.master", "local")
      .getOrCreate()

    import spark.implicits._

    val dfB = List(child1(1)).toDS()
    val dfC = List(child2(1)).toDS()

    matcherDef(dfB) // prints "child1"
    matcherDef(dfC) // prints "child2"
  }
}

case class child1(i: Int) extends parent(i)
case class child2(i: Int) extends parent(i)

class parent(j: Int)
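A more direct variant, not in the original answer, is to take the ClassTag as a context bound and compare Class objects instead of strings, which also works when the classes live inside a package (matcherDef2 is an illustrative name; the same child1/child2/parent definitions are assumed):

import scala.reflect.ClassTag
import org.apache.spark.sql.Dataset

// Sketch: dispatch on the Dataset's element type via a ClassTag context bound,
// avoiding the round-trip through toJavaRDD and string comparison
def matcherDef2[T <: parent : ClassTag](ds: Dataset[T]): Unit =
  implicitly[ClassTag[T]].runtimeClass match {
    case c if c == classOf[child1] => println("child1")
    case c if c == classOf[child2] => println("child2")
    case _                         => println("Unknown")
  }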

