pyspark:在gridsearch为空后获取最佳模型的参数{}

问题描述 投票:1回答:3

有人可以帮助我从网格搜索中提取表现最佳的模型参数吗?由于某种原因,它是一个空白的字典。

from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit, CrossValidator
from pyspark.ml.evaluation import BinaryClassificationEvaluator


train, test = df.randomSplit([0.66, 0.34], seed=12345)

paramGrid = (ParamGridBuilder()
             .addGrid(lr.regParam, [0.01,0.1])
             .addGrid(lr.elasticNetParam, [1.0,])
             .addGrid(lr.maxIter, [3,])
             .build())

evaluator = BinaryClassificationEvaluator(rawPredictionCol="rawPrediction",labelCol="buy")
evaluator.setMetricName('areaUnderROC')

cv = CrossValidator(estimator=pipeline,
                          estimatorParamMaps=paramGrid,
                          evaluator=evaluator,
                          numFolds=2)  
cvModel = cv.fit(train)

> print(cvModel.bestModel) #it looks like I have a valid bestModel
PipelineModel_406e9483e92ebda90524 In [8]:

> cvModel.bestModel.extractParamMap() #fails
 {} In [9]:

> cvModel.bestModel.getRegParam() #also fails
> 
> AttributeError                            Traceback (most recent call
> last) <ipython-input-9-747196173391> in <module>()
> ----> 1 cvModel.bestModel.getRegParam()
> 
> AttributeError: 'PipelineModel' object has no attribute 'getRegParam'
python apache-spark pyspark apache-spark-ml grid-search
3个回答
4
投票

这里有两个不同的问题:

  • 参数设置在单独的EstiamtorsTransformers而不是PipelineModel。可以使用stages属性访问所有模型。
  • 在Spark 2.3之前,Python模型根本不包含ParamsSPARK-10931)。

因此,除非您使用开发分支,否则您必须在分支机构access its _java_obj and get parameters of interest中找到感兴趣的模型。例如:

from pyspark.ml.classification import LogisticRegressionModel

[x._java_obj.getRegParam() 
for x in cvModel.bestModel.stages if isinstance(x, LogisticRegressionModel)]

0
投票

试试这个:

cvModel.bestModel.stages[-1].extractParamMap()

你可以用你喜欢的任何数字改变-1。


0
投票

我最近遇到过这个问题,最适合我的解决方案是从extractParamMap创建一个关键名称及其值的字典,然后使用它来获取我想要的名称值。

best_mod = cvModel.bestModel
param_dict = best_mod.stages[-1].extractParamMap()

sane_dict = {}
for k, v in param_dict.items():
  sane_dict[k.name] = v

best_reg = sane_dict["regParam"]
best_elastic_net = sane_dict["elasticNetParam"]
best_max_iter = sane_dict["maxIter"]

希望这可以帮助!

© www.soinside.com 2019 - 2024. All rights reserved.