Spark ML API将矢量转换为多标签分类的概率

问题描述 投票:-1回答:1

我对Spark ML API有点新意。我正在尝试通过训练160个分类器(后勤或随机森林等)对160个标签进行多标签分类。一旦我在Dataset [LabeledPoint]上训练,我发现很难获得一个API,我得到了每个类的概率。我已经阅读了SO,你可以使用管道API并获得概率,但对于我的用例,这将很难,因为我必须为我的评估功能复制160个RDD,获得每个类的概率然后根据概率对类进行排名。相反,我想只有一个评估功能的副本,广播160个模型,然后在地图功能中进行预测。我发现自己必须实现这一点,但想知道Spark中是否有另一个便利API可以为Logistic / RF等不同的分类器做同样的事情,它将表示要素的Vector转换为属于类的概率。如果有更好的方法来处理Spark中的多标签分类,请告诉我。

编辑:我试图创建一个函数来将矢量转换为随机森林的标签,但它非常烦人,因为我现在必须在Spark中克隆大块树遍历,几乎在任何地方我遇到死角因为某些函数或变量是私人或受保护的。如果错误,请纠正我,但如果这个用例尚未实现,我认为它至少是合理的,因为Scikit-learn已经有了这样的API来实现这一点。

谢谢

scala apache-spark machine-learning apache-spark-ml
1个回答
0
投票

在Spark MLLib代码中找到了罪魁祸首:https://github.com/apache/spark/blob/5ad644a4cefc20e4f198d614c59b8b0f75a228ba/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala#L224

预测方法被标记为受保护但实际上应该是公共的,以支持此类用例。

这已在版本2.4中修复,如下所示:https://github.com/apache/spark/blob/branch-2.4/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala

所以升级到版本2.4应该可以做到这一点......虽然我不认为2.4已经出现了,所以这是等待的问题。

编辑:对于那些感兴趣的人来说,显然这不仅有利于多标签预测,而且已经观察到,对于单个实例/小批量预测,常规分类/回归的延迟也提高了3-4倍(详见https://issues.apache.org/jira/browse/SPARK-16198) )。

© www.soinside.com 2019 - 2024. All rights reserved.