How to get the same result in R when using a loop and parallel processing?

Question (votes: 1, answers: 1)

I am testing the effect of the training data on classification accuracy, using the iris data as an example. I noticed that I got the best accuracy in iteration 33, and I would like to use the training set (irisTrain) from that iteration for further analysis. I don't know how to reproduce it: I don't want to save the training set at every iteration because it is too large; I only want to recover the set from iteration 33. I tried calling clusterSetRNGStream() and then reusing the same seed in a loop (roughly as sketched below), but it did not give the same result.
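Roughly what I tried (a reconstruction; the exact seed value is a placeholder):

library(parallel)
cl <- makeCluster(detectCores() - 1)
clusterSetRNGStream(cl, iseed = 123)  # placeholder seed, set before every run
# ... then re-running the loop below with the same iseed
# still produced a different partition each time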

library(randomForest)
library(caret)
library(foreach)
library(doParallel)

results_overall <- data.frame()

cores = detectCores()
cl = makeCluster(cores - 1)

registerDoParallel(cl)
res <- foreach(i = 1:50, .packages = c("caret", "randomForest"), .combine = rbind) %dopar% {

trainIndex <- caret::createDataPartition(iris$Species, p = 0.5, list = FALSE)
irisTrain <- iris[ trainIndex,]
irisTest  <- iris[-trainIndex,]

model <- randomForest(x = irisTrain[,c(1:4)], y = irisTrain[,5], importance = TRUE,
                      replace = TRUE, mtry = 4, ntree = 500, na.action = na.omit,
                      do.trace = 100, type = "classification")

pred_test <- predict(model, irisTest[,c(1:4)])
con.mat_test <- confusionMatrix(pred_test, irisTest[,5], mode ="everything")

results_overall <- rbind(results_overall, con.mat_test[["overall"]])

return(tibble::tribble(~iteration, ~overall, 
                       i, results_overall))
}
stopCluster(cl)
Tags: r, foreach, parallel-processing, caret, doparallel
1 Answer

3 votes

The variation across your iterations comes from the random sampling performed by caret::createDataPartition. To make it reproducible, you can use the doRNG package, which was written for exactly this purpose. Many thanks to @HenrikB for enlightening me.

Edit: fixed the foreach function (this does not change the results).

invisible(suppressPackageStartupMessages(
    lapply(c("data.table", "randomForest", "caret", "foreach", 
             "doRNG", "rngtools", "doParallel"),
           require, character.only = TRUE)))
cores = detectCores()
cl = makeCluster(cores - 1)
registerDoParallel(cl)
# %dorng% gives every iteration its own reproducible RNG stream, seeded
# through .options.RNG; the per-iteration seeds are stored in attr(res, "rng")
res <- foreach(i = 1:50, .packages = c("caret", "randomForest", "data.table"), .combine = rbind,
               .options.RNG = 1234) %dorng% {
                   trainIndex <- caret::createDataPartition(iris$Species, p = 0.5, list = FALSE)
                   irisTrain <- iris[ trainIndex,]
                   irisTest  <- iris[-trainIndex,]
                   model <- randomForest(x = irisTrain[,c(1:4)], y = irisTrain[,5], importance = TRUE,
                                         replace = TRUE, mtry = 4, ntree = 500, na.action=na.omit,
                                         do.trace = 100, type = "classification")
                   pred_test <- predict(model, irisTest[,c(1:4)])
                   con.mat_test <- confusionMatrix(pred_test, irisTest[,5], mode ="everything")
                   return(data.table(Iteration=i, t(con.mat_test[["overall"]])))
               }
stopCluster(cl)
seeds <- attr(res, 'rng')  # per-iteration RNG seeds recorded by doRNG
res[which.min(Accuracy),]  # the iteration to reproduce (here: lowest accuracy)
#>    Iteration  Accuracy Kappa AccuracyLower AccuracyUpper AccuracyNull
#> 1:         6 0.9066667  0.86     0.8171065     0.9616461    0.3333333
#>    AccuracyPValue McnemarPValue
#> 1:    4.39803e-25           NaN

best.seed <- res[which.min(Accuracy),]$Iteration  # index of the chosen iteration

rngtools::setRNG(seeds[[best.seed]])  # restore that iteration's RNG state
trainIndex <- caret::createDataPartition(iris$Species, p = 0.5, list = FALSE)
irisTrain <- iris[ trainIndex,]
irisTest  <- iris[-trainIndex,]

model <- randomForest(x = irisTrain[,c(1:4)], y = irisTrain[,5], importance = TRUE,
                      replace = TRUE, mtry = 4, ntree = 500, na.action=na.omit,
                      do.trace = 100, type = "classification")
#> ntree      OOB      1      2      3
#>   100:   4.00%  0.00%  4.00%  8.00%
#>   200:   2.67%  0.00%  4.00%  4.00%
#>   300:   2.67%  0.00%  4.00%  4.00%
#>   400:   2.67%  0.00%  4.00%  4.00%
#>   500:   4.00%  0.00%  4.00%  8.00%
pred_test <- predict(model, irisTest[,c(1:4)])
con.mat_test <- confusionMatrix(pred_test, irisTest[,5], mode ="everything")
con.mat_test[["overall"]]
#>       Accuracy          Kappa  AccuracyLower  AccuracyUpper   AccuracyNull 
#>   9.066667e-01   8.600000e-01   8.171065e-01   9.616461e-01   3.333333e-01 
#> AccuracyPValue  McnemarPValue 
#>   4.398030e-25            NaN

Created on 2020-05-05 by the reprex package (v0.3.0)
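As a sanity check (a minimal sketch, not part of the original answer): restoring the same stored seed a second time makes createDataPartition return an identical split.

rngtools::setRNG(seeds[[best.seed]])  # restore the stored stream once more
trainIndex2 <- caret::createDataPartition(iris$Species, p = 0.5, list = FALSE)
identical(trainIndex, trainIndex2)  # should be TRUE if the state was restored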


0 votes

Try putting set.seed(1) before training the model. For more information, run ?set.seed() in R, where it is described quite well. The seed number is the starting point for generating a sequence of random numbers. More information here.
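A minimal sequential sketch of this suggestion: the same seed reproduces the same partition within a single R session. Note, however, that set.seed() in the master session does not control the RNG streams of the doParallel workers, which is why the accepted answer uses doRNG instead.

set.seed(1)
idx1 <- caret::createDataPartition(iris$Species, p = 0.5, list = FALSE)
set.seed(1)
idx2 <- caret::createDataPartition(iris$Species, p = 0.5, list = FALSE)
identical(idx1, idx2)  # TRUE: same seed, same split (sequentially)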
