[XGBoost每个工人集成一个XGBoost模型

问题描述 投票:1回答:1

尝试使用此笔记本https://databricks-prod-cloudfront.cloud.databricks.com/public/4027ec902e239c93eaaa8714f173bcfc/1526931011080774/3624187670661048/6320440561800420/latest.html

使用spark版本2.4.3和xgboost 0.90

[试图执行时保持此错误ValueError: bad input shape () ...

features = inputTrainingDF.select("features").collect()
lables = inputTrainingDF.select("label").collect()

X = np.asarray(map(lambda v: v[0].toArray(), features))
Y = np.asarray(map(lambda v: v[0], lables))

xgbClassifier = xgb.XGBClassifier(max_depth=3, seed=18238, objective='binary:logistic')

model = xgbClassifier.fit(X, Y)
ValueError: bad input shape () 

def trainXGbModel(partitionKey, labelAndFeatures):
  X = np.asarray(map(lambda v: v[1].toArray(), labelAndFeatures))
  Y = np.asarray(map(lambda v: v[0], labelAndFeatures))
  xgbClassifier = xgb.XGBClassifier(max_depth=3, seed=18238, objective='binary:logistic' )
  model =  xgbClassifier.fit(X, Y)
  return [partitionKey, model]

xgbModels = inputTrainingDF\
.select("education", "label", "features")\
.rdd\
.map(lambda row: [row[0], [row[1], row[2]]])\
.groupByKey()\
.map(lambda v: trainXGbModel(v[0], list(v[1])))

xgbModels.take(1)
ValueError: bad input shape ()

您可以在笔记本中看到它对发布者有效。我的猜测是它与XY np.asarray()映射有关,因为逻辑只是试图将标签和要素映射到函数,但是形状为空。使用此代码即可正常工作]

pandasDF = inputTrainingDF.toPandas()
series = pandasDF['features'].apply(lambda x : np.array(x.toArray())).as_matrix().reshape(-1,1)
features = np.apply_along_axis(lambda x : x[0], 1, series)
target = pandasDF['label'].values
xgbClassifier = xgb.XGBClassifier(max_depth=3, seed=18238, objective='binary:logistic' )
model = xgbClassifier.fit(features, target)

但是想集成到原始功能调用中并了解为什么原始笔记本无法正常工作。非常感谢您解决此问题!

apache-spark pyspark apache-spark-mllib xgboost apache-spark-ml
1个回答
0
投票

您可能正在使用python3。问题是在python3 map函数中返回的是迭代器对象,而不是集合。解决此问题所需要做的就是更改map-> list(map(...))

def trainXGbModel(partitionKey, labelAndFeatures):
  X = np.asarray(list(map(lambda v: v[1].toArray(), labelAndFeatures)))
  Y = np.asarray(list(map(lambda v: v[0], labelAndFeatures)))

或者您可以使用np.fromiter将可迭代对象转换为numpy数组。

© www.soinside.com 2019 - 2024. All rights reserved.