如何使用CrossValidator在不同的模型之间进行选择

问题描述 投票:2回答:1

我知道可以使用CrossValidator来调整单个模型。但是,如果要在不同的模型之间相互比较评估,建议的做法是什么?例如,假设我想使用CrossValidator来对比评估LogisticRegression分类器和LinearSVC分类器。

scala apache-spark apache-spark-mllib cross-validation
1个回答
3
投票

在熟悉了API之后,我通过实现一个自定义的Estimator来解决这个问题,该自定义Estimator包含它可以委派的两个或更多个估算器,其中所选的估算器由单个Param[Int]控制。这是实际的代码:

```scala
import org.apache.spark.ml.Estimator
import org.apache.spark.ml.Model
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.Params
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types.StructType

trait DelegatingEstimatorModelParams extends Params {
  final val selectedEstimator = new Param[Int](this, "selectedEstimator", "The selected estimator")
}

class DelegatingEstimator private (override val uid: String, delegates: Array[Estimator[_]]) extends Estimator[DelegatingEstimatorModel] with DelegatingEstimatorModelParams {
  private def this(estimators: Array[Estimator[_]]) = this(Identifiable.randomUID("delegating-estimator"), estimators)

  def this(estimator1: Estimator[_], estimator2: Estimator[_], estimators: Estimator[_]*) = {
    this((Seq(estimator1, estimator2) ++ estimators).toArray)
  }

  setDefault(selectedEstimator -> 0)

  override def fit(dataset: Dataset[_]): DelegatingEstimatorModel = {
    val estimator = delegates(getOrDefault(selectedEstimator))
    val model = estimator.fit(dataset).asInstanceOf[Model[_]]
    new DelegatingEstimatorModel(uid, model)
  }

  override def copy(extra: ParamMap): Estimator[DelegatingEstimatorModel] = {
    val that = new DelegatingEstimator(uid, delegates)
    copyValues(that, extra)
  }

  override def transformSchema(schema: StructType): StructType = {
    // All delegates are assumed to perform the same schema transformation,
    // so we can simply select the first one:
    delegates(0).transformSchema(schema)
  }
}

class DelegatingEstimatorModel(override val uid: String, val delegate: Model[_]) extends Model[DelegatingEstimatorModel] with DelegatingEstimatorModelParams {
  def copy(extra: ParamMap): DelegatingEstimatorModel = new DelegatingEstimatorModel(uid, delegate.copy(extra).asInstanceOf[Model[_]])

  def transform(dataset: Dataset[_]): DataFrame = delegate.transform(dataset)

  def transformSchema(schema: StructType): StructType = delegate.transformSchema(schema)
}
```

上面的类可以这样使用,用来对比评估LogisticRegression和LinearSVC:

```scala
val logRegression = new LogisticRegression()
  .setFeaturesCol(columnNames.features)
  .setPredictionCol(columnNames.prediction)
  .setRawPredictionCol(columnNames.rawPrediciton)
  .setLabelCol(columnNames.label)

val svmEstimator = new LinearSVC()
  .setFeaturesCol(columnNames.features)
  .setPredictionCol(columnNames.prediction)
  .setRawPredictionCol(columnNames.rawPrediciton)
  .setLabelCol(columnNames.label)

val delegatingEstimator = new DelegatingEstimator(logRegression, svmEstimator)

val paramGrid = new ParamGridBuilder()
  .addGrid(delegatingEstimator.selectedEstimator, Array(0, 1))
  .build()

val model = crossValidator.fit(data)

val bestModel = model.bestModel.asInstanceOf[DelegatingEstimatorModel].delegate
```
© www.soinside.com 2019 - 2024. All rights reserved.