如何在嵌套 cv(mlr3 基准)期间提取外循环样本的预测(例如生存概率)?

问题描述 投票:0回答:1
  • 注意:此问题的代码/分析输出已根据要求使用 reprex() 进行编辑,以提高重现性

我希望获得一些关于提取分布预测(例如特定时间点的事件概率)、从嵌套简历的外循环中提取线性预测变量以及从嵌套简历的内循环中提取基线危险的指导(其中模型是在 mlr3 基准测试过程中开发的,以计算校准指数。

以肺部数据集为例:

rm(list = (ls(all=T)))
library(reprex)
library(mlr3)
library(mlr3learners)
library(mlr3extralearners)
library(mlr3pipelines)
library(mlr3tuning)
#> Loading required package: paradox
library(mlr3proba)
library(data.table)

# set up data
task_lung = tsk('lung')
d = task_lung$data()
d$time = ceiling(d$time/30.44)
task_lung = as_task_surv(d, time = 'time', event = 'status', id = 'lung')
po_encode = po('encode', method = 'treatment')
po_impute = po('imputelearner', lrn('regr.rpart'))
pre = po_encode %>>% po_impute
task = pre$train(task_lung)[[1]]


# learners
cph=lrn("surv.coxph")
# get baseline hazard estimates
comp.cph = as_learner(ppl(  
  "distrcompositor",
  learner = cph,
  estimator = "kaplan",
  form = "ph"
))



# Benchmark above 3 (in outer 4 folds)
set.seed(123)
BM1 = benchmark(benchmark_grid(task,
                      list(cph, comp.cph),
                      rsmp('cv', folds=4)),
                store_models =T)
#> INFO  [13:00:22.564] [mlr3] Running benchmark with 8 resampling iterations
#> INFO  [13:00:22.649] [mlr3] Applying learner 'surv.coxph' on task 'lung' (iter 1/4)
#> INFO  [13:00:22.716] [mlr3] Applying learner 'surv.coxph' on task 'lung' (iter 2/4)
#> INFO  [13:00:22.766] [mlr3] Applying learner 'surv.coxph' on task 'lung' (iter 3/4)
#> INFO  [13:00:22.825] [mlr3] Applying learner 'surv.coxph' on task 'lung' (iter 4/4)
#> INFO  [13:00:22.873] [mlr3] Applying learner 'distrcompositor.kaplan.surv.coxph.distrcompose' on task 'lung' (iter 1/4)
#> INFO  [13:00:23.029] [mlr3] Applying learner 'distrcompositor.kaplan.surv.coxph.distrcompose' on task 'lung' (iter 2/4)
#> INFO  [13:00:23.225] [mlr3] Applying learner 'distrcompositor.kaplan.surv.coxph.distrcompose' on task 'lung' (iter 3/4)
#> INFO  [13:00:23.365] [mlr3] Applying learner 'distrcompositor.kaplan.surv.coxph.distrcompose' on task 'lung' (iter 4/4)
#> INFO  [13:00:23.510] [mlr3] Finished benchmark

# extract (inner) models
mdl=mlr3misc::map(as.data.table(BM1)$learner, "model")

# check that the LP is for internal loop
# the following has 171 obs as expected from inner loops
mdl[[1]][["linear.predictors"]]
#>   [1]  1.18499215  0.31858695  0.99802300  0.58241827  0.15774088 -0.12313273
#>   [7]  0.27176230  0.31255547  0.87977282  1.66127703  1.27052489  0.32452650
#>  [13] -0.39147459  0.18128928  0.40221932  0.32179235  1.29488433  0.22421192
#>  [19]  0.36696516  0.02351139  0.03435003  0.30268183  0.46156831  0.67854111
#>  [25] -0.38466200  0.18029125 -0.30794694  1.81985726  0.72143148  0.73993999
#>  [31] -0.19820495 -0.61578221  1.41832763  0.61361444  0.26284730  1.59711534
#>  [37]  0.08964261 -0.39070785  0.34200863  0.63126957  1.02630566  0.01239987
#>  [43]  0.23198859 -0.26151613 -0.39089134  0.19280407 -0.87206586 -0.62197655
#>  [49]  0.88013192  0.08972523 -0.74319591 -1.03042024 -0.27852335  0.08151983
#>  [55]  1.71881259 -0.67763599  1.11294971  0.88264769  0.10990293 -0.02842130
#>  [61]  0.37947992  0.49182045  0.77795647  0.71053557  0.65215179  0.27799555
#>  [67] -0.37213111  1.88280007  1.19325541 -0.12267881  1.15952023 -0.94928614
#>  [73]  0.86114411  0.31579099  1.91908641  0.77277223 -0.04938800  0.39655017
#>  [79] -0.02114651  1.77783288  1.01903263  0.67158262  0.49602602 -0.53871485
#>  [85]  1.13274466  0.24812544  1.61907952 -0.01141713  0.85803147  0.94607458
#>  [91] -0.41149162  0.35978972  1.53106530  1.31347748  0.12968634  0.95718654
#>  [97]  1.21293718  0.90307591 -0.10667084  0.87119681  0.94642450 -0.40357218
#> [103]  0.21266674  0.54770469  0.60351288  0.61024865 -1.15985267 -0.60745992
#> [109]  0.64926122  0.94255144  0.10625227 -0.17035191  0.35624501  0.78856855
#> [115]  0.82062314  0.64819105  0.29187039  0.83125774  1.40013302  0.64771200
#> [121]  1.15793176  0.62013293  0.13332034  0.11379916  0.53277392  0.67303587
#> [127]  0.07299108 -0.22845207 -0.37492090 -0.22763493  0.18024949 -0.45301382
#> [133] -0.11000587  0.10124783  0.11111948  0.43462859  1.19243908  0.87027351
#> [139]  1.32904450  0.14259790  0.39245370  0.62142923 -0.33833996  0.48248721
#> [145]  0.94495937  1.09431650  0.78709427  0.18538695 -0.26797215 -0.40705985
#> [151]  1.49430202  1.47283382 -1.12637695 -0.16080892 -0.35788332  0.57572753
#> [157] -0.85507993 -0.05270265 -0.11091513  0.60276879 -0.35633053  0.47495284
#> [163]  0.81178544  0.73027874 -0.43725719 -0.22938202 -0.28038346  0.54849467
#> [169]  0.06449779  1.38581779 -0.29679293

# extract baseline hazard 
# Assuming that these are from inner loops (which is required for distributional prediction)
mdl[[5]][["distrcompositor.kaplan"]][["model"]][["std.chaz"]]
#>  [1] 0.01547223 0.02062259 0.02649866 0.03237961 0.03735763 0.04532140
#>  [7] 0.05320536 0.05999788 0.06527124 0.07576031 0.08515480 0.10010706
#> [13] 0.10877590 0.11161182 0.12888880 0.13672413 0.13672413 0.14575496
#> [19] 0.15721555 0.16497494 0.17407800 0.19293410 0.20412743 0.23869255
#> [25] 0.27111447 0.32365577 0.36404785 0.36404785 0.49360100 0.49360100


# Extract outer learners based on a previous question: "mlr3, benchmarking and nested resampling: how to extract a tuned model from a benchmark object to calculate feature importance, the suggested code to extract learners fitted in the outer loop" @ https://github.com/mlr-org/mlr3/issues/601

data = as.data.table(BM1)
outer_learners = mlr3misc::map(data$learner, "learner"); outer_learners  #outer_learners is null
#> [[1]]
#> NULL
#> 
#> [[2]]
#> NULL
#> 
#> [[3]]
#> NULL
#> 
#> [[4]]
#> NULL
#> 
#> [[5]]
#> NULL
#> 
#> [[6]]
#> NULL
#> 
#> [[7]]
#> NULL
#> 
#> [[8]]
#> NULL


# Based on the following https://github.com/mlr-org/mlr3/issues/601, it may be possible to save the best model from inner cv and apply this model for prediction on a new/outer dataset (This is not working too)
measures = list (
  msr('surv.cindex', id='cindex'),
  msr('surv.graf', id='brier'))
best_mdl <- BM1$score(measures)[learner_id=='encode.scale.surv.glmnet.tuned',][cindex==max(cindex)]$learners
#> Warning in max(cindex): no non-missing arguments to max; returning -Inf

reprex 包于 2024 年 5 月 3 日创建(v2.0.1)

会议信息
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value                       
#>  version  R version 4.1.1 (2021-08-10)
#>  os       macOS Big Sur 10.16         
#>  system   x86_64, darwin17.0          
#>  ui       X11                         
#>  language (EN)                        
#>  collate  en_US.UTF-8                 
#>  ctype    en_US.UTF-8                 
#>  tz       Australia/Adelaide          
#>  date     2024-05-03                  
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package           * version     date       lib
#>  backports           1.4.1       2021-12-13 [1]
#>  bbotk               0.7.3.9000  2023-11-22 [1]
#>  checkmate           2.3.1       2023-12-04 [1]
#>  cli                 3.6.2       2023-12-11 [1]
#>  codetools           0.2-18      2020-11-04 [2]
#>  colorspace          2.1-0       2023-01-23 [1]
#>  crayon              1.4.1       2021-02-08 [2]
#>  data.table        * 1.15.4      2024-03-30 [1]
#>  dictionar6          0.1.3       2021-09-13 [1]
#>  digest              0.6.35      2024-03-11 [1]
#>  distr6              1.8.4       2024-05-02 [1]
#>  dplyr               1.1.3       2023-09-03 [1]
#>  evaluate            0.23        2023-11-01 [1]
#>  fansi               1.0.6       2023-12-08 [1]
#>  fastmap             1.1.0       2021-01-25 [2]
#>  fs                  1.5.0       2020-07-31 [2]
#>  future              1.33.2      2024-03-26 [1]
#>  future.apply        1.11.2      2024-03-28 [1]
#>  generics            0.1.3       2022-07-05 [1]
#>  ggplot2             3.5.1       2024-04-23 [1]
#>  globals             0.16.3      2024-03-08 [1]
#>  glue                1.7.0       2024-01-09 [1]
#>  gtable              0.3.5       2024-04-22 [1]
#>  highr               0.9         2021-04-16 [2]
#>  htmltools           0.5.6       2023-08-10 [1]
#>  knitr               1.33        2021-04-24 [2]
#>  lattice             0.20-44     2021-05-02 [2]
#>  lgr                 0.4.4       2022-09-05 [1]
#>  lifecycle           1.0.4       2023-11-07 [1]
#>  listenv             0.9.1       2024-01-29 [1]
#>  magrittr            2.0.3       2022-03-30 [1]
#>  Matrix              1.3-4       2021-06-01 [2]
#>  mlr3              * 0.19.0      2024-04-24 [1]
#>  mlr3extralearners * 0.7.1       2023-11-24 [1]
#>  mlr3learners      * 0.5.7.9000  2023-11-24 [1]
#>  mlr3misc            0.15.0      2024-04-10 [1]
#>  mlr3pipelines     * 0.5.0-9000  2023-11-22 [1]
#>  mlr3proba         * 0.6.1       2024-05-02 [1]
#>  mlr3tuning        * 0.19.1.9000 2023-11-22 [1]
#>  mlr3viz             0.8.0       2024-03-05 [1]
#>  munsell             0.5.1       2024-04-01 [1]
#>  ooplah              0.2.0       2022-01-21 [1]
#>  palmerpenguins      0.1.1       2022-08-15 [1]
#>  paradox           * 0.11.1-9000 2023-11-22 [1]
#>  parallelly          1.37.1      2024-02-29 [1]
#>  param6              0.2.4       2023-11-22 [1]
#>  pillar              1.9.0       2023-03-22 [1]
#>  pkgconfig           2.0.3       2019-09-22 [2]
#>  R6                  2.5.1       2021-08-19 [1]
#>  Rcpp                1.0.12      2024-01-09 [1]
#>  reprex            * 2.0.1       2021-08-05 [1]
#>  RhpcBLASctl         0.23-42     2023-02-11 [1]
#>  rlang               1.1.3       2024-01-10 [1]
#>  rmarkdown           2.10        2021-08-06 [2]
#>  rpart               4.1-15      2019-04-12 [2]
#>  rstudioapi          0.15.0      2023-07-07 [1]
#>  scales              1.3.0       2023-11-28 [1]
#>  sessioninfo         1.1.1       2018-11-05 [2]
#>  set6                0.2.6       2023-11-22 [1]
#>  stringi             1.7.3       2021-07-16 [2]
#>  stringr             1.5.0       2022-12-02 [1]
#>  survival            3.5-7       2023-08-14 [1]
#>  survivalmodels      0.1.191     2024-03-19 [1]
#>  tibble              3.2.1       2023-03-20 [1]
#>  tidyselect          1.2.0       2022-10-10 [1]
#>  utf8                1.2.4       2023-10-22 [1]
#>  uuid                1.2-0       2024-01-14 [1]
#>  vctrs               0.6.5       2023-12-01 [1]
#>  withr               3.0.0       2024-01-16 [1]
#>  xfun                0.25        2021-08-06 [2]
#>  yaml                2.2.1       2020-02-01 [2]
#>  source                                    
#>  CRAN (R 4.1.0)                            
#>  https://mlr-org.r-universe.dev (R 4.1.1)  
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  Github (xoopR/distr6@a7c01f7)             
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  Github (mlr-org/mlr3extralearners@6e2af9e)
#>  Github (mlr-org/mlr3learners@86f19eb)     
#>  CRAN (R 4.1.1)                            
#>  https://mlr-org.r-universe.dev (R 4.1.1)  
#>  https://mlr-org.r-universe.dev (R 4.1.1)  
#>  https://mlr-org.r-universe.dev (R 4.1.1)  
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.2)                            
#>  https://mlr-org.r-universe.dev (R 4.1.1)  
#>  CRAN (R 4.1.1)                            
#>  Github (xoopR/param6@0fa3577)             
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  Github (xoopR/set6@a901255)               
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.0)                            
#> 
#> [1] /Users/Lee/Library/R/x86_64/4.1/library
#> [2] /Library/Frameworks/R.framework/Versions/4.1/Resources/library

我非常感谢任何有关如何在 mlr3 中的嵌套 cv 基准测试过程中提取内循环上的相关结果并将这些结果应用于外循环样本上的预测任务的建议/建议。

mlr3
1个回答
0
投票

简短回答

...原标题:您已经有了外循环的预测,它位于

data$prediction
槽中。

代码详情/建议

  • 使用
    reprex()
    时,最好将其输出到 GitHub 问题,因为这里没有太大区别。
learners
cph=lrn("surv.coxph")
# get baseline hazard estimates

备注:

  1. distrcompositor
    组成 S(t)(生存预测),您没有得到基线 S_0(t)。使用 CoxPH 模型,您不需要编写 S(t),
    survival
    包默认使用的 S(t)(Breslow)就足够了。我想稍后你只需要经过训练的 Kaplan 拟合 (
    survfit
    ),所以你可以只使用
    lrn("surv.kaplan")
  2. 参数
    overwrite
    默认为
    FALSE
    ,因此您无需在此处使用
    distrcompositor
    添加任何内容。只是让你的生活变得更加困难,让你之后在
    PipeOp
    中找到学习者(如果你需要的话)。
comp.cph = as_learner(ppl(  
  "distrcompositor",
  learner = cph,
  estimator = "kaplan",
  form = "ph"
))
comp.cph # see that `distrcompose.overwrite=FALSE`, so SAME S(t) as `cph` learner
lasso = as_learner(
  po("encode") %>>% 
  po('scale') %>>% # standardize = TRUE by default in glmnet so this is re-done
    # Note: good to write the arguments, eg `tuner = ...`
  auto_tuner(
    tuner = tnr("grid_search", resolution = 10, batch_size =10),
    learner = lrn("surv.glmnet", alpha=1, s = to_tune(0.005, 1)),
    resampling = rsmp("cv", folds=2),
    measure = msr(c("surv.cindex")),
    store_models = TRUE,
    terminator = trm("stagnation", iters=50, threshold=0.01))
  )

我假设上面提取的基线危险也基于内部样本(这是 cph/lasso 的外部样本(分布)预测所必需的)

是的,因为:

mdl[[9]]$distrcompositor.kaplan$model$n (171 = #samples in 3 out of 4 folds)

用于提取安装在外循环中的学习者的建议代码不起作用:

这是一个不同的示例,我在运行它时遇到了不同的错误:

data = as.data.table(BM1)
outer_learners = mlr3misc::map(data$learner, "learner")
Error in `[[.R6`(X[[i]], ...) :
R6 class LearnerSurvCoxPH/LearnerSurv/Learner/R6 does not have slot 'learner'!

但是看,

data$learner
是您要求的学习者列表,即学习者的“调整”/“训练”版本(在 4 倍中的 3 倍上),准备用于进行预测。

...也许可以从内部CV中保存最佳模型(是

data$learner
)并应用此模型对新/外部数据集进行预测:

measures = list(msr('surv.cindex', id='cindex'), msr('surv.graf', id='brier'))
BM1$score(measures) # this works for me, no problem

我非常感谢任何关于 如何从 mlr3 中的嵌套 cv 基准测试程序中提取内循环上的相关结果并将这些结果应用于外循环样本上的预测任务的提示/建议。

外环样本预测如下(#samples = 57 = 228/4,即 1 个外层):

data$prediction
  • 请将所有与 mlr3 相关的软件包更新到最新的 CRAN 或 github 版本,主要是 mlr3proba
    mlr3extralearners
    mlr3pipelines
    paradox
    ,因为您的版本有点旧,可能存在一些差异。请注意,所有这些软件包最近都有一些核心更新,以使事情变得更好更快,等等,总的来说,这样做很好。
© www.soinside.com 2019 - 2024. All rights reserved.