randomForest类的err.rate组件是什么意思?

问题描述 投票:2回答:1

我正在使用软件包randomForest中的函数randomForesterr.rate是randomForest类的对象之一,它是

(仅分类)输入数据的预测矢量错误率,第i个元素是直到第i个树的所有树的(OOB)错误率。

您能否解释一下此组件的含义?非常感谢您的帮助!

我以数据集Sonar, Mines vs. Rocks作为代码示例。

library(mlbench)
data(Sonar)
library(boot)
library(randomForest)

n <- 208
ntrain <- 100
ntest <- 108
train.idx <- sample(1:n, ntrain, replace = FALSE)
train.set <- Sonar[train.idx, ]
test.set <- Sonar[-train.idx, ]

rf <- randomForest(Class ~ ., data = train.set, keep.inbag = TRUE, importance = TRUE)
head(rf$err.rate)

这里是代码的结果

             OOB         M         R
  [1,] 0.1891892 0.1500000 0.2352941
  [2,] 0.2931034 0.2307692 0.3437500
  [3,] 0.2739726 0.2647059 0.2820513
  [4,] 0.2911392 0.2894737 0.2926829
  [5,] 0.2413793 0.2682927 0.2173913
  [6,] 0.2555556 0.2142857 0.2916667
  [7,] 0.2553191 0.2444444 0.2653061
  [8,] 0.2268041 0.1956522 0.2549020
  [9,] 0.2783505 0.2608696 0.2941176
r random-forest
1个回答
0
投票

randomForest的一个组成部分是套袋,您可以从i棵树中获得共识预测。

随着增加树的数量,将在每个步骤中计算OOB错误。不能通过将从1棵树获得的预测与相对于该树的OOB样本进行比较来计算OOB误差,而是使用不使用该样本的树上的平均预测。我建议检查this for an overview

因此,在您的示例中,我们可以将其可视化:

library(ggplot2)
library(tidyr)

plotdf <- pivot_longer(data.frame(ntrees=1:nrow(rf$err.rate),rf$err.rate),-ntrees)
ggplot(plotdf,aes(x=ntrees,y=value,col=name)) + 
geom_line() + theme_bw()

enter image description here

M和R是该特定标签的预测误差线,而OOB(第一列)仅是两者的平均值。随着树数的增加,因为您可以从更多树中获得更好的预测,所以OOB错误会降低。

关于randomForest的好处是,您不需要交叉验证,因为OOB估计值通常具有指示性。下面我们可以尝试显示出相同的结果:

set.seed(12)
# split in 5 parts
trn = split(1:nrow(Sonar),sample(1:nrow(Sonar) %% 5))
sim = vector("list",5)
# the number of trees we incrementally grow
ntrees = c(1,20*(1:50)+1)

for(CV in 1:5){
idx = trn[[CV]]
train.set <- Sonar[-idx, ]
test.set <- Sonar[idx, ]
# first forest, n=1, but works
mdl <- randomForest(Class ~ ., data = train.set, ntree=1,
keep.inbag = TRUE, importance = TRUE,keep.forest=TRUE)
err_rate <- vector("numeric",51)
err_rate[1] <- mean(predict(mdl,test.set)!=test.set$Class)
#growing the tree
for(i in 1:50){
  mdl <- grow(mdl,10)
  err_rate[i+1] <- mean(predict(mdl,test.set)!=test.set$Class)
}
sim[[CV]] <- data.frame(ntrees=ntrees,err_rate=err_rate,CV=CV)
}
sim = do.call(rbind,sim)

#plot

ggplot(sim,aes(x=ntrees,y=err_rate)) + geom_line(aes(group=CV),alpha=0.2) + 
stat_summary(fun.y=mean,geom="line",col="blue")+theme_bw()

enter image description here

© www.soinside.com 2019 - 2024. All rights reserved.