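I got the following error message when trying to run cross-validation on a GBT model. Fitting the GBT model on its own had worked without problems before.
Is fitMultiple no longer supported? I am using PySpark 2.4.4.
Here is my code: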
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
paramGrid = ParamGridBuilder().addGrid(gbt.maxDepth, [2, 4, 6]).addGrid(gbt.maxBins, [20, 60]).addGrid(gbt.maxIter, [10, 20]).build()
evaluator = BinaryClassificationEvaluator()
cv = CrossValidator(estimator=gbt, estimatorParamMaps=paramGrid, evaluator=evaluator)
# Run cross validations. This can take about 6 minutes since it is training over 20 trees!
cvModel = cv.fit(train)
predictions = cvModel.transform(test)
evaluator.evaluate(predictions)
Here is the error message:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-150-95e5a7d270fe> in <module>
6
7 # Run cross validations. This can take about 6 minutes since it is training over 20 trees!
----> 8 cvModel = cv.fit(train)
9 predictions = cvModel.transform(test)
10 evaluator.evaluate(predictions)
/anaconda_env/personal/gleow/py37/lib/python3.7/site-packages/pyspark/ml/base.py in fit(self, dataset, params)
130 return self.copy(params)._fit(dataset)
131 else:
--> 132 return self._fit(dataset)
133 else:
134 raise ValueError("Params must be either a param map or a list/tuple of param maps, "
/anaconda_env/personal/gleow/py37/lib/python3.7/site-packages/pyspark/ml/tuning.py in _fit(self, dataset)
301 train = df.filter(~condition).cache()
302
--> 303 tasks = _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam)
304 for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
305 metrics[j] += (metric / nFolds)
/anaconda_env/personal/gleow/py37/lib/python3.7/site-packages/pyspark/ml/tuning.py in _parallelFitTasks(est, train, eva, validation, epm, collectSubModel)
47 :return: (int, float, subModel), an index into `epm` and the associated metric value.
48 """
---> 49 modelIter = est.fitMultiple(train, epm)
50
51 def singleTask():
AttributeError: 'GBTClassificationModel' object has no attribute 'fitMultiple'
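I went back through my code and found the mistake. It turns out I had already initialized gbt earlier and then overwritten it with the fitted model: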
gbt = GBTClassifier(maxIter=10)
gbt = gbt.fit(train)
...
...
So when I later ran the following:
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
paramGrid = ...
evaluator = ...
cv = CrossValidator(estimator=gbt, estimatorParamMaps=paramGrid, evaluator=evaluator)
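CrossValidator was being handed the result of 'gbt.fit(train)', which is a fitted GBTClassificationModel and no longer a classifier (estimator) object, so it has no fitMultiple method. What I had to do was re-initialize gbt = GBTClassifier(maxIter=10) before passing it to CrossValidator.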
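
To make the cause easier to see without a Spark session, here is a minimal sketch in plain Python. ToyEstimator, ToyModel and toy_cross_validate are hypothetical stand-ins invented for illustration, not PySpark classes; they only mimic the relevant behavior, namely that an estimator exposes fit and fitMultiple while the model returned by fit does not. The sketch reproduces the same AttributeError when the variable name is rebound to the fitted model, and shows that keeping the estimator and the model under separate names avoids it.

```python
# Toy stand-ins (NOT the real PySpark classes) that mimic the estimator/model split.

class ToyModel:
    """Result of fitting: it can transform data but cannot be fitted again."""

    def __init__(self, max_iter):
        self.max_iter = max_iter

    def transform(self, rows):
        # Dummy prediction: pair every row with a constant score.
        return [(row, 0.0) for row in rows]


class ToyEstimator:
    """Unfitted learner: fit() returns a ToyModel, fitMultiple() fits one per param map."""

    def __init__(self, max_iter=10):
        self.max_iter = max_iter

    def fit(self, rows):
        return ToyModel(self.max_iter)

    def fitMultiple(self, rows, param_maps):
        # Yield (index, model) pairs, one for each parameter combination.
        for index, params in enumerate(param_maps):
            yield index, ToyEstimator(**params).fit(rows)


def toy_cross_validate(est, rows, param_maps):
    """Mimics the call seen in the traceback: est.fitMultiple(train, epm)."""
    return [index for index, _model in est.fitMultiple(rows, param_maps)]


train = [1, 2, 3]
grid = [{"max_iter": 10}, {"max_iter": 20}]

# Bug pattern: the name gbt is rebound to the fitted model.
gbt = ToyEstimator(max_iter=10)
gbt = gbt.fit(train)  # gbt is now a ToyModel, not an estimator
try:
    toy_cross_validate(gbt, train, grid)
    error_message = None
except AttributeError as exc:
    error_message = str(exc)

# Fix: keep the estimator and the fitted model under different names.
gbt = ToyEstimator(max_iter=10)
gbt_model = gbt.fit(train)
fitted_indices = toy_cross_validate(gbt, train, grid)
```

In the real code the same idea would apply: assign the result of gbt.fit(train) to a separate name (gbt_model here is just an illustrative choice) so that gbt is still a GBTClassifier when it reaches CrossValidator.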