访问随机森林sparkR中的概率对象列

问题描述 投票:0回答:1

在SparkR中,我有这样的内容

rf <- spark.randomForest(train, formula, type = "classification")
pred <- predict(rf,test) 

执行中

head(pred) 

输出就是您在图像中看到的内容

enter image description here

Converting SparkR predictions to readable format (number or string)

我如何获得概率值?

r sparkr
1个回答
0
投票

您必须使用函数values为每个对象调用名为sparkR.callJMethod的Java方法。

t(sapply(collect(select(pred, "probability"))$probability, 
         FUN = function(x) sparkR.callJMethod(x, "values")))

这里是使用虹膜数据集的完整示例。目标值为Species,具有3个级别,并且总共有150个数据点。

df <- createDataFrame(iris)
model <- spark.randomForest(df, Species ~ ., type = "classification")
summary(model)

predictions <- predict(model, df)

local_prob <- collect(select(predictions, "probability"))$probability

t(sapply(local_prob, FUN = function(x) sparkR.callJMethod(x, "values")))

请注意,这些预测是收集的,如果数据集很大,则可能内存不足。如果是这样,则可以改用head

截断的输出:

       [,1]        [,2]        [,3]
  [1,] 0           0           1   
  [2,] 0           0           1   
  [3,] 0           0           1   
...  
[148,] 0           1           0   
[149,] 0.05        0.95        0   
[150,] 0.01805556  0.9819444   0  

© www.soinside.com 2019 - 2024. All rights reserved.