Spark: FlatMap and CountVectorizer pipeline


I am working on a Pipeline and trying to split a column's value before passing it to CountVectorizer.

For this purpose I made a custom Transformer.

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, explode, udf}
import org.apache.spark.sql.types.{DataTypes, StringType, StructType}

class FlatMapTransformer(override val uid: String)
  extends Transformer {
  /**
   * Param for input column name.
   * @group param
   */
  final val inputCol = new Param[String](this, "inputCol", "The input column")
  final def getInputCol: String = $(inputCol)

  /**
   * Param for output column name.
   * @group param
   */
  final val outputCol = new Param[String](this, "outputCol", "The output column")
  final def getOutputCol: String = $(outputCol)

  def setInputCol(value: String): this.type = set(inputCol, value)
  def setOutputCol(value: String): this.type = set(outputCol, value)

  def this() = this(Identifiable.randomUID("FlatMapTransformer"))

  private val flatMap: String => Seq[String] = { input: String =>
    input.split(",")
  }

  override def copy(extra: ParamMap): SplitString = defaultCopy(extra)

  override def transform(dataset: Dataset[_]): DataFrame = {
    val flatMapUdf = udf(flatMap)
    dataset.withColumn($(outputCol), explode(flatMapUdf(col($(inputCol)))))
  }

  override def transformSchema(schema: StructType): StructType = {
    val dataType = schema($(inputCol)).dataType
    require(
      dataType.isInstanceOf[StringType],
      s"Input column must be of type StringType but got ${dataType}")
    val inputFields = schema.fields
    require(
      !inputFields.exists(_.name == $(outputCol)),
      s"Output column ${$(outputCol)} already exists.")

    DataTypes.createStructType(
      Array(
        DataTypes.createStructField($(outputCol), DataTypes.StringType, false)))
  }
}

The code seems fine on its own, but the problem appears when I try to chain it with other stages. Here is my pipeline:

val train = reader.readTrainingData()

val cat_features = getFeaturesByType(taskConfig, "categorical")
val num_features = getFeaturesByType(taskConfig, "numeric")
val cat_ohe_features = getFeaturesByType(taskConfig, "categorical", Some("ohe"))
val cat_features_string_index = cat_features.
  filter { feature: String => !cat_ohe_features.contains(feature) }

val catIndexer = cat_features_string_index.map {
  feature =>
    new StringIndexer()
      .setInputCol(feature)
      .setOutputCol(feature + "_index")
      .setHandleInvalid("keep")
}

val flatMapper = cat_ohe_features.map {
  feature =>
    new FlatMapTransformer()
      .setInputCol(feature)
      .setOutputCol(feature + "_transformed")
}

val countVectorizer = cat_ohe_features.map {
  feature =>
    new CountVectorizer()
      .setInputCol(feature + "_transformed")
      .setOutputCol(feature + "_vectorized")
      .setVocabSize(10)
}


// val countVectorizer = cat_ohe_features.map {
//   feature =>
//
//     val flatMapper = new FlatMapTransformer()
//       .setInputCol(feature)
//       .setOutputCol(feature + "_transformed")
// 
//     new CountVectorizer()
//       .setInputCol(flatMapper.getOutputCol)
//       .setOutputCol(feature + "_vectorized")
//       .setVocabSize(10)
// }

val cat_features_index = cat_features_string_index.map {
  (feature: String) => feature + "_index"
}

val count_vectorized_index = cat_ohe_features.map {
  (feature: String) => feature + "_vectorized"
}

val catFeatureAssembler = new VectorAssembler()
  .setInputCols(cat_features_index)
  .setOutputCol("cat_features")

val oheFeatureAssembler = new VectorAssembler()
  .setInputCols(count_vectorized_index)
  .setOutputCol("cat_ohe_features")

val numFeatureAssembler = new VectorAssembler()
  .setInputCols(num_features)
  .setOutputCol("num_features")

val featureAssembler = new VectorAssembler()
  .setInputCols(Array("cat_features", "num_features", "cat_ohe_features_vectorized"))
  .setOutputCol("features")

val pipelineStages = catIndexer ++ flatMapper ++ countVectorizer ++
  Array(
    catFeatureAssembler,
    oheFeatureAssembler,
    numFeatureAssembler,
    featureAssembler)

val pipeline = new Pipeline().setStages(pipelineStages)
pipeline.fit(dataset = train)

Running this code, I get an error: java.lang.IllegalArgumentException: Field "my_ohe_field_transformed" does not exist.

[info]  java.lang.IllegalArgumentException: Field "from_expdelv_areas_transformed" does not exist.
[info]  at org.apache.spark.sql.types.StructType$$anonfun$apply$1.apply(StructType.scala:266)
[info]  at org.apache.spark.sql.types.StructType$$anonfun$apply$1.apply(StructType.scala:266)
[info]  at scala.collection.MapLike$class.getOrElse(MapLike.scala:128)
[info]  at scala.collection.AbstractMap.getOrElse(Map.scala:59)
[info]  at org.apache.spark.sql.types.StructType.apply(StructType.scala:265)
[info]  at org.apache.spark.ml.util.SchemaUtils$.checkColumnTypes(SchemaUtils.scala:56)
[info]  at org.apache.spark.ml.feature.CountVectorizerParams$class.validateAndTransformSchema(CountVectorizer.scala:75)
[info]  at org.apache.spark.ml.feature.CountVectorizer.validateAndTransformSchema(CountVectorizer.scala:123)
[info]  at org.apache.spark.ml.feature.CountVectorizer.transformSchema(CountVectorizer.scala:188)

When I uncomment the stringSplitter and countVectorizer block, the error is raised inside my Transformer:

java.lang.IllegalArgumentException: Field "my_ohe_field" does not exist. at val dataType = schema($(inputCol)).dataType

The result of calling pipeline.getStages:

strIdx_3c2630a738f0
strIdx_0d76d55d4200
FlatMapTransformer_fd8595c2969c
FlatMapTransformer_2e9a7af0b0fa
cntVec_c2ef31f00181
cntVec_68a78eca06c9
vecAssembler_a81dd9f43d56
vecAssembler_b647d348f0a0
vecAssembler_b5065a22d5c8
vecAssembler_d9176b8bb593

I am probably going about this the wrong way. Any advice is appreciated.

scala apache-spark apache-spark-mllib countvectorizer
1 Answer

2 votes

Your FlatMapTransformer#transform is incorrect: by selecting only the outputCol you drop/ignore all the other columns.

Please modify your method to -

override def transform(dataset: Dataset[_]): DataFrame = {
  val flatMapUdf = udf(flatMap)
  dataset.withColumn($(outputCol), explode(flatMapUdf(col($(inputCol)))))
}
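For illustration, here is a small spark-shell style sketch (made-up column names, not from the question) of why selecting only the new column loses the other columns while withColumn keeps them:

// Assumes a spark-shell session, so a SparkSession `spark` and spark.implicits._ are in scope.
import org.apache.spark.sql.functions.{col, udf}

val splitUdf = udf { s: String => s.split(",").toSeq }
val df = Seq(("r1", "a,b")).toDF("id", "tags")

df.select(splitUdf(col("tags")).as("tags_transformed")).columns
// Array(tags_transformed)            -- "id" has been dropped
df.withColumn("tags_transformed", splitUdf(col("tags"))).columns
// Array(id, tags, tags_transformed)  -- all original columns are kept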

Also, modify transformSchema to check for the presence of the input column before checking its datatype -

override def transformSchema(schema: StructType): StructType = {
  require(schema.names.contains($(inputCol)), "inputCol is not there in the input dataframe")
  //... rest as it is
}
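For example, with a hypothetical schema that lacks the input column, the added require fails with the explicit message instead of the generic StructType error seen in the stack trace above:

import org.apache.spark.sql.types.{StringType, StructField, StructType}

val schema = StructType(Seq(StructField("other_col", StringType)))
// schema("my_ohe_field")  // would throw: Field "my_ohe_field" does not exist.
require(
  schema.names.contains("my_ohe_field"),
  "inputCol is not there in the input dataframe")
// throws: java.lang.IllegalArgumentException: requirement failed:
//         inputCol is not there in the input dataframe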

Update-1, based on the comments

1. Please modify the copy method (though this is not the cause of the exception you are seeing) -

override def copy(extra: ParamMap): FlatMapTransformer = defaultCopy(extra)

2. Please note that CountVectorizer accepts a column of type ArrayType(StringType, true/false), and since the output column of FlatMapTransformer becomes the input of CountVectorizer, you need to make sure that the output column of FlatMapTransformer is of type ArrayType(StringType). I don't think that is the case; your code today is as follows -

override def transform(dataset: Dataset[_]): DataFrame = {
  val flatMapUdf = udf(flatMap)
  dataset.withColumn($(outputCol), explode(flatMapUdf(col($(inputCol)))))
}

The explode function flattens the array&lt;string&gt; into individual strings, so the output of the transformer becomes StringType. You may want to change this code to -

override def transform(dataset: Dataset[_]): DataFrame = {
  val flatMapUdf = udf(flatMap)
  dataset.withColumn($(outputCol), flatMapUdf(col($(inputCol))))
}

3. Modify the transformSchema method so that it outputs ArrayType(StringType) -
override def transformSchema(schema: StructType): StructType = {
  val dataType = schema($(inputCol)).dataType
  require(
    dataType.isInstanceOf[StringType],
    s"Input column must be of type StringType but got ${dataType}")
  val inputFields = schema.fields
  require(
    !inputFields.exists(_.name == $(outputCol)),
    s"Output column ${$(outputCol)} already exists.")

  schema.add($(outputCol), ArrayType(StringType))
}

4. Change the vector assembler to this -

val featureAssembler = new VectorAssembler()
  .setInputCols(Array("cat_features", "num_features", "cat_ohe_features"))
  .setOutputCol("features")

I tried to execute your pipeline on a dummy DataFrame and it worked well.
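A minimal sketch of what such a dummy-DataFrame check could look like, assuming the corrected transform (no explode), the ArrayType(StringType) transformSchema and the fixed copy method above; the column names and the standalone object are made up:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.CountVectorizer
import org.apache.spark.sql.SparkSession

object FlatMapPipelineCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("flatmap-check").getOrCreate()
    import spark.implicits._

    // Dummy data: a single comma-separated categorical column.
    val df = Seq("a,b,c", "a,c", "b").toDF("my_ohe_field")

    // The corrected custom transformer splits the string into an array<string> column.
    val flatMapper = new FlatMapTransformer()
      .setInputCol("my_ohe_field")
      .setOutputCol("my_ohe_field_transformed")

    // CountVectorizer consumes the array<string> column produced above.
    val vectorizer = new CountVectorizer()
      .setInputCol("my_ohe_field_transformed")
      .setOutputCol("my_ohe_field_vectorized")
      .setVocabSize(10)

    val model = new Pipeline().setStages(Array(flatMapper, vectorizer)).fit(df)
    val result = model.transform(df)

    // Without explode the transformed column stays array<string>, which is what
    // CountVectorizer expects; printSchema makes that visible.
    result.printSchema()
    result.show(truncate = false)

    spark.stop()
  }
}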
