I am working with a Pipeline and trying to split a column's values before passing them to CountVectorizer.
For this purpose I wrote a custom Transformer.
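import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, explode, udf}
import org.apache.spark.sql.types.{DataTypes, StringType, StructType}
class FlatMapTransformer(override val uid: String)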
extends Transformer {
/**
* Param for input column name.
* @group param
*/
final val inputCol = new Param[String](this, "inputCol", "The input column")
final def getInputCol: String = $(inputCol)
/**
* Param for output column name.
* @group param
*/
final val outputCol = new Param[String](this, "outputCol", "The output column")
final def getOutputCol: String = $(outputCol)
def setInputCol(value: String): this.type = set(inputCol, value)
def setOutputCol(value: String): this.type = set(outputCol, value)
def this() = this(Identifiable.randomUID("FlatMapTransformer"))
private val flatMap: String => Seq[String] = { input: String =>
input.split(",")
}
override def copy(extra: ParamMap): SplitString = defaultCopy(extra)
override def transform(dataset: Dataset[_]): DataFrame = {
val flatMapUdf = udf(flatMap)
dataset.withColumn($(outputCol), explode(flatMapUdf(col($(inputCol)))))
}
override def transformSchema(schema: StructType): StructType = {
val dataType = schema($(inputCol)).dataType
require(
dataType.isInstanceOf[StringType],
s"Input column must be of type StringType but got ${dataType}")
val inputFields = schema.fields
require(
!inputFields.exists(_.name == $(outputCol)),
s"Output column ${$(outputCol)} already exists.")
DataTypes.createStructType(
Array(
DataTypes.createStructField($(outputCol), DataTypes.StringType, false)))
}
}
The code looks legitimate, but the problem appears when I try to chain it with other stages. Here is my pipeline:
val train = reader.readTrainingData()
val cat_features = getFeaturesByType(taskConfig, "categorical")
val num_features = getFeaturesByType(taskConfig, "numeric")
val cat_ohe_features = getFeaturesByType(taskConfig, "categorical", Some("ohe"))
val cat_features_string_index = cat_features.
filter { feature: String => !cat_ohe_features.contains(feature) }
val catIndexer = cat_features_string_index.map {
feature =>
new StringIndexer()
.setInputCol(feature)
.setOutputCol(feature + "_index")
.setHandleInvalid("keep")
}
val flatMapper = cat_ohe_features.map {
feature =>
new FlatMapTransformer()
.setInputCol(feature)
.setOutputCol(feature + "_transformed")
}
val countVectorizer = cat_ohe_features.map {
feature =>
new CountVectorizer()
.setInputCol(feature + "_transformed")
.setOutputCol(feature + "_vectorized")
.setVocabSize(10)
}
// val countVectorizer = cat_ohe_features.map {
// feature =>
//
// val flatMapper = new FlatMapTransformer()
// .setInputCol(feature)
// .setOutputCol(feature + "_transformed")
//
// new CountVectorizer()
// .setInputCol(flatMapper.getOutputCol)
// .setOutputCol(feature + "_vectorized")
// .setVocabSize(10)
// }
val cat_features_index = cat_features_string_index.map {
(feature: String) => feature + "_index"
}
val count_vectorized_index = cat_ohe_features.map {
(feature: String) => feature + "_vectorized"
}
val catFeatureAssembler = new VectorAssembler()
.setInputCols(cat_features_index)
.setOutputCol("cat_features")
val oheFeatureAssembler = new VectorAssembler()
.setInputCols(count_vectorized_index)
.setOutputCol("cat_ohe_features")
val numFeatureAssembler = new VectorAssembler()
.setInputCols(num_features)
.setOutputCol("num_features")
val featureAssembler = new VectorAssembler()
.setInputCols(Array("cat_features", "num_features", "cat_ohe_features_vectorized"))
.setOutputCol("features")
val pipelineStages = catIndexer ++ flatMapper ++ countVectorizer ++
Array(
catFeatureAssembler,
oheFeatureAssembler,
numFeatureAssembler,
featureAssembler)
val pipeline = new Pipeline().setStages(pipelineStages)
pipeline.fit(dataset = train)
Running this code, I get an error: java.lang.IllegalArgumentException: Field "my_ohe_field_transformed" does not exist.
[info] java.lang.IllegalArgumentException: Field "from_expdelv_areas_transformed" does not exist.
[info] at org.apache.spark.sql.types.StructType$$anonfun$apply$1.apply(StructType.scala:266)
[info] at org.apache.spark.sql.types.StructType$$anonfun$apply$1.apply(StructType.scala:266)
[info] at scala.collection.MapLike$class.getOrElse(MapLike.scala:128)
[info] at scala.collection.AbstractMap.getOrElse(Map.scala:59)
[info] at org.apache.spark.sql.types.StructType.apply(StructType.scala:265)
[info] at org.apache.spark.ml.util.SchemaUtils$.checkColumnTypes(SchemaUtils.scala:56)
[info] at org.apache.spark.ml.feature.CountVectorizerParams$class.validateAndTransformSchema(CountVectorizer.scala:75)
[info] at org.apache.spark.ml.feature.CountVectorizer.validateAndTransformSchema(CountVectorizer.scala:123)
[info] at org.apache.spark.ml.feature.CountVectorizer.transformSchema(CountVectorizer.scala:188)
When I uncomment the stringSplitter and countVectorizer block, the error is raised in my Transformer instead:
java.lang.IllegalArgumentException: Field "my_ohe_field" does not exist.
at val dataType = schema($(inputCol)).dataType
Result of calling pipeline.getStages:
strIdx_3c2630a738f0
strIdx_0d76d55d4200
FlatMapTransformer_fd8595c2969c
FlatMapTransformer_2e9a7af0b0fa
cntVec_c2ef31f00181
cntVec_68a78eca06c9
vecAssembler_a81dd9f43d56
vecAssembler_b647d348f0a0
vecAssembler_b5065a22d5c8
vecAssembler_d9176b8bb593
I may be going about this the wrong way. Any advice is appreciated.
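Your FlatMapTransformer#transform is incorrect: when you select only outputCol, you are effectively dropping/ignoring all the other columns. Keep them by using withColumn: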
override def transform(dataset: Dataset[_]): DataFrame = {
val flatMapUdf = udf(flatMap)
dataset.withColumn($(outputCol), explode(flatMapUdf(col($(inputCol)))))
}
Also, modify transformSchema so that it first checks that the input column exists, and only then checks its data type:
override def transformSchema(schema: StructType): StructType = {
  require(schema.names.contains($(inputCol)), "inputCol is not there in the input dataframe")
  // ... rest as it is
}
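The effect of checking existence before type can be seen in a small plain-Scala sketch. It does not use Spark: the schema is modeled as a hypothetical ToySchema (a list of column-name/type-name pairs), and SchemaCheckSketch and validateInput are illustrative names. Spark's own StructType.apply already throws "Field ... does not exist", as in the stack trace above; the explicit check just fails first, with a message that names the transformer's input column.

```scala
// Plain-Scala sketch, not Spark's API: a "schema" is modeled as an ordered
// list of (columnName, typeName) pairs.
object SchemaCheckSketch {
  type ToySchema = Seq[(String, String)]

  // Same order as the suggested transformSchema: first require that the
  // column exists, only then look up its type and check it.
  def validateInput(schema: ToySchema, inputCol: String): Unit = {
    require(schema.exists(_._1 == inputCol),
      s"$inputCol is not there in the input dataframe")
    val dataType = schema.find(_._1 == inputCol).map(_._2).get
    require(dataType == "string",
      s"Input column must be of type string but got $dataType")
  }
}
```

Calling validateInput(Seq("other" -> "string"), "my_ohe_field") fails on the first require, before any type lookup is attempted.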
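Update-1 based on comments
Fix the return type of the copy method (although this is not why you get the exception):
override def copy(extra: ParamMap): FlatMapTransformer = defaultCopy(extra)
CountVectorizer accepts a column of type ArrayType(StringType, true/false), and since the output column of FlatMapTransformer becomes the input of CountVectorizer, you need to make sure that the output column of FlatMapTransformer is ArrayType(StringType, true/false). I don't think that is the case; your code today is as follows: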
override def transform(dataset: Dataset[_]): DataFrame = {
val flatMapUdf = udf(flatMap)
dataset.withColumn($(outputCol), explode(flatMapUdf(col($(inputCol)))))
}
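The explode function turns each array<string> into one row per element, so the output column of your transformer becomes StringType. You probably want to drop explode and change this code to:
override def transform(dataset: Dataset[_]): DataFrame = {
  val flatMapUdf = udf(flatMap)
  // keep the Seq[String] returned by the UDF, i.e. an ArrayType(StringType) column
  dataset.withColumn($(outputCol), flatMapUdf(col($(inputCol))))
}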
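The difference between the two versions of transform can be illustrated with plain Scala collections, where Seq.flatMap plays the role of explode. This is only an analogy (SplitVsExplodeSketch and its methods are hypothetical names, and no Spark is involved): each input String stands for one row of the input column. In Spark, explode additionally repeats the row's other columns once per token.

```scala
// Plain Scala collections used as an analogy for the two DataFrame
// operations (this is not Spark code): each String is one input row.
object SplitVsExplodeSketch {
  // Like flatMapUdf without explode: one output row per input row,
  // each holding a sequence of tokens (analogous to ArrayType(StringType)).
  def splitOnly(rows: Seq[String]): Seq[Seq[String]] =
    rows.map(_.split(",").toSeq)

  // Like explode(flatMapUdf(...)): one output row per token,
  // each holding a single string (analogous to StringType).
  def splitAndExplode(rows: Seq[String]): Seq[String] =
    rows.flatMap(_.split(",").toSeq)
}
```

Only the first shape keeps one token list per row, which is what CountVectorizer expects as input.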
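Next, modify the transformSchema method so that it outputs ArrayType(StringType). Your current version builds a brand-new schema that contains only the output column, so every later stage loses its input columns when the Pipeline validates the schema; appending the column to the existing schema avoids that: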
override def transformSchema(schema: StructType): StructType = {
  // also needs: import org.apache.spark.sql.types.ArrayType
  val dataType = schema($(inputCol)).dataType
  require(
    dataType.isInstanceOf[StringType],
    s"Input column must be of type StringType but got ${dataType}")
  val inputFields = schema.fields
  require(
    !inputFields.exists(_.name == $(outputCol)),
    s"Output column ${$(outputCol)} already exists.")
  // append the new column instead of returning a one-column schema
  schema.add($(outputCol), ArrayType(StringType))
}
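Why a one-column schema breaks the pipeline can be reproduced with a toy model. Pipeline.fit validates the stages up front by passing the schema through each stage's transformSchema in order (which is why the stack trace shows CountVectorizer.transformSchema). In the sketch below, schemas are plain Seq[String] of column names rather than StructTypes, and SchemaFoldSketch, Stage and appendOutput are hypothetical names used only for illustration.

```scala
// Toy model, not Spark's API: a schema is just the ordered list of column names.
object SchemaFoldSketch {
  type ToySchema = Seq[String]

  // A hypothetical stage that reads inputCol and produces outputCol.
  // appendOutput = true  behaves like schema.add(outputCol, ...)
  // appendOutput = false behaves like building a brand-new one-column schema
  final case class Stage(inputCol: String, outputCol: String, appendOutput: Boolean) {
    def transformSchema(schema: ToySchema): ToySchema = {
      require(schema.contains(inputCol), s"""Field "$inputCol" does not exist.""")
      if (appendOutput) schema :+ outputCol else Seq(outputCol)
    }
  }

  // Validate a whole pipeline by threading the schema through every stage in order.
  def validate(initial: ToySchema, stages: Seq[Stage]): ToySchema =
    stages.foldLeft(initial)((schema, stage) => stage.transformSchema(schema))
}
```

With appendOutput = false, the first stage discards every other column, so the second stage fails on its own input column. This mirrors the Field "my_ohe_field" does not exist error from the question.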
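Finally, change the vector assembler to this (oheFeatureAssembler writes to cat_ohe_features; there is no cat_ohe_features_vectorized column):
val featureAssembler = new VectorAssembler()
  .setInputCols(Array("cat_features", "num_features", "cat_ohe_features"))
  .setOutputCol("features")
I tried to run your pipeline on a dummy DataFrame, and it worked well. Please refer to the linked code for the full version.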