'GBTClassificationModel' object has no attribute 'fitMultiple' - pyspark

Question (votes: 1, answers: 1)

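I get the following error message when trying to run cross-validation on a GBT. Previously, fitting the GBT model on its own worked without problems.

Is fitMultiple no longer supported? I am using PySpark 2.4.4.

Here is my code: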

from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

paramGrid = ParamGridBuilder().addGrid(gbt.maxDepth, [2, 4, 6]).addGrid(gbt.maxBins, [20, 60]).addGrid(gbt.maxIter, [10, 20]).build()
evaluator = BinaryClassificationEvaluator()
cv = CrossValidator(estimator=gbt, estimatorParamMaps=paramGrid, evaluator=evaluator)

# Run cross validations.  This can take about 6 minutes since it is training over 20 trees!
cvModel = cv.fit(train)
predictions = cvModel.transform(test)
evaluator.evaluate(predictions)

Here is the error message:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-150-95e5a7d270fe> in <module>
      6 
      7 # Run cross validations.  This can take about 6 minutes since it is training over 20 trees!
----> 8 cvModel = cv.fit(train)
      9 predictions = cvModel.transform(test)
     10 evaluator.evaluate(predictions)

/anaconda_env/personal/gleow/py37/lib/python3.7/site-packages/pyspark/ml/base.py in fit(self, dataset, params)
    130                 return self.copy(params)._fit(dataset)
    131             else:
--> 132                 return self._fit(dataset)
    133         else:
    134             raise ValueError("Params must be either a param map or a list/tuple of param maps, "

/anaconda_env/personal/gleow/py37/lib/python3.7/site-packages/pyspark/ml/tuning.py in _fit(self, dataset)
    301             train = df.filter(~condition).cache()
    302 
--> 303             tasks = _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam)
    304             for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
    305                 metrics[j] += (metric / nFolds)

/anaconda_env/personal/gleow/py37/lib/python3.7/site-packages/pyspark/ml/tuning.py in _parallelFitTasks(est, train, eva, validation, epm, collectSubModel)
     47     :return: (int, float, subModel), an index into `epm` and the associated metric value.
     48     """
---> 49     modelIter = est.fitMultiple(train, epm)
     50 
     51     def singleTask():

AttributeError: 'GBTClassificationModel' object has no attribute 'fitMultiple'
python machine-learning pyspark apache-spark-ml
1 answer
0 votes

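After troubleshooting my code I found the mistake. Earlier in the notebook I had initialized gbt and then fitted it, reassigning the result to the same name: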

gbt = GBTClassifier(maxIter=10)
gbt = gbt.fit(train)
...
...

So when I later tried to run the following:

from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

paramGrid = ...
evaluator = ...
cv = CrossValidator(estimator=gbt, estimatorParamMaps=paramGrid, evaluator=evaluator)

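CrossValidator was actually being handed the result of 'gbt.fit(train)', which is no longer a classifier (estimator) object but a fitted GBTClassificationModel, and a fitted model has no fitMultiple method. So fitMultiple is still supported in 2.4.4; what I had to do was initialize gbt again with gbt = GBTClassifier(maxIter=10) before building the CrossValidator.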
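One way to catch this kind of mix-up earlier is a small guard that runs before the CrossValidator is built. The sketch below is only an illustration: require_estimator is a hypothetical helper, not part of PySpark, and gbt_model is a hypothetical variable name. It duck-types on the fitMultiple method that appears in the traceback above, so it does not import Spark itself.

```python
def require_estimator(obj, name="estimator"):
    """Return obj unchanged if it looks like an unfitted pyspark.ml estimator.

    Hypothetical helper (not a PySpark API): it only checks for the
    fitMultiple method that CrossValidator calls in the traceback above.
    """
    if not callable(getattr(obj, "fitMultiple", None)):
        raise TypeError(
            "%s is a %s, which has no fitMultiple(); pass the unfitted "
            "estimator (e.g. GBTClassifier), not the result of .fit()"
            % (name, type(obj).__name__)
        )
    return obj


# Intended usage with the names from the answer (shown as comments because
# it needs a running Spark session and the train DataFrame):
#   gbt = GBTClassifier(maxIter=10)   # estimator, kept for tuning
#   gbt_model = gbt.fit(train)        # fitted model under a separate name
#   cv = CrossValidator(estimator=require_estimator(gbt, "gbt"),
#                       estimatorParamMaps=paramGrid, evaluator=evaluator)
```

Keeping the fitted model under its own name (gbt_model here) instead of reassigning gbt avoids the problem in the first place. The guard just turns the failure into a clearer message at the point where the wrong object is passed in.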
