使用 parsnip 更新模型参数，而无需重新指定 fit() 函数的参数

问题描述 投票:0回答:1

我正在用 Titanic 数据集试用 parsnip 包。

library(titanic)
library(dplyr)
library(tidymodels)
library(rattle)
library(rpart.plot)
library(RColorBrewer)

# Training data: convert the label and the categorical predictors to factors
# so the classification tree treats them as categories.
train <- mutate(
  titanic_train,
  Survived = factor(Survived),
  Sex = factor(Sex),
  Embarked = factor(Embarked)
)

# Test data: only the categorical predictors are converted here (Survived is
# not touched, presumably because titanic_test has no label column).
test <- mutate(
  titanic_test,
  Sex = factor(Sex),
  Embarked = factor(Embarked)
)

# Model specification: a classification decision tree using the rpart engine.
spec_obj <- set_engine(decision_tree(mode = "classification"), "rpart")
spec_obj

# Fit the specification to the training data.
fit_obj <- fit(
  spec_obj,
  Survived ~ Pclass + Sex + Age + SibSp + Parch + Fare + Embarked,
  data = train
)
fit_obj

# Plot the engine-level model stored in fit_obj$fit.
fancyRpartPlot(fit_obj$fit)

# Predictions for the test set.
pred <- predict(fit_obj, new_data = test)
pred

假设我想为模型设定添加一些参数（例如 min_n 和 cost_complexity）：

# Change hyperparameters on the existing specification.
spec_obj <- update(spec_obj, min_n = 50, cost_complexity = 0)

# Refitting requires passing the same formula and data to fit() again; this
# repetition is the problem being asked about.
fit_obj <- fit(
  spec_obj,
  Survived ~ Pclass + Sex + Age + SibSp + Parch + Fare + Embarked,
  data = train
)
fit_obj
fancyRpartPlot(fit_obj$fit)

有什么方法可以避免在 fit() 函数中第二次指定公式？

# Store the model formula once so every refit can reuse it. A formula literal
# is used directly; there is no need to write it as a string and parse it
# with as.formula().
f <- Survived ~ Pclass + Sex + Age + SibSp + Parch + Fare + Embarked
fit_obj <-
  spec_obj %>%
  fit(f, data = train)
fit_obj

不过，是否还有更好的方法？

r tidymodels
1个回答
0
投票

我认为最好的方法是创建一个小的封装函数，比如叫做 fit_titanic()：

library(titanic)
library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
library(tidymodels)
#> ── Attaching packages ────────────────────────────────────────────── tidymodels 0.1.0 ──
#> ✓ broom     0.5.5      ✓ recipes   0.1.10
#> ✓ dials     0.0.6      ✓ rsample   0.0.6 
#> ✓ ggplot2   3.3.0      ✓ tibble    3.0.1 
#> ✓ infer     0.5.1      ✓ tune      0.1.0 
#> ✓ parsnip   0.1.0      ✓ workflows 0.1.1 
#> ✓ purrr     0.3.4      ✓ yardstick 0.0.6
#> ── Conflicts ───────────────────────────────────────────────── tidymodels_conflicts() ──
#> x purrr::discard()  masks scales::discard()
#> x dplyr::filter()   masks stats::filter()
#> x dplyr::lag()      masks stats::lag()
#> x ggplot2::margin() masks dials::margin()
#> x recipes::step()   masks stats::step()

# Convert the label and the categorical predictors to factors.
train <- mutate(
  titanic_train,
  Survived = factor(Survived),
  Sex = factor(Sex),
  Embarked = factor(Embarked)
)


# Base specification: classification decision tree using the rpart engine,
# with no hyperparameters set.
spec1 <- set_engine(decision_tree(mode = "classification"), "rpart")

spec1
#> Decision Tree Model Specification (classification)
#> 
#> Computational engine: rpart

# Fit a parsnip model specification with the Titanic survival formula.
#
# The formula and data live in one place, so any specification derived via
# update() can be refitted with a single call: fit_titanic(new_spec).
#
# spec:    a parsnip model specification (e.g. from decision_tree()).
# data:    training data. The default `train` is a lazy default: it is looked
#          up when the function is called, in the environment where
#          fit_titanic was defined. This matches the original behaviour of
#          always using the global `train`.
# formula: model formula; defaults to the original survival formula.
# Returns the fitted model object produced by fit().
fit_titanic <- function(spec,
                        data = train,
                        formula = Survived ~ Pclass + Sex + Age + SibSp +
                          Parch + Fare + Embarked) {
  fit(spec, formula = formula, data = data)
}


# Fit the unmodified specification (spec1 sets no hyperparameters).
fit_titanic(spec1)
#> parsnip model object
#> 
#> Fit time:  17ms 
#> n= 891 
#> 
#> node), split, n, loss, yval, (yprob)
#>       * denotes terminal node
#> 
#>   1) root 891 342 0 (0.61616162 0.38383838)  
#>     2) Sex=male 577 109 0 (0.81109185 0.18890815)  
#>       4) Age>=6.5 553  93 0 (0.83182640 0.16817360) *
#>       5) Age< 6.5 24   8 1 (0.33333333 0.66666667)  
#>        10) SibSp>=2.5 9   1 0 (0.88888889 0.11111111) *
#>        11) SibSp< 2.5 15   0 1 (0.00000000 1.00000000) *
#>     3) Sex=female 314  81 1 (0.25796178 0.74203822)  
#>       6) Pclass>=2.5 144  72 0 (0.50000000 0.50000000)  
#>        12) Fare>=23.35 27   3 0 (0.88888889 0.11111111) *
#>        13) Fare< 23.35 117  48 1 (0.41025641 0.58974359)  
#>          26) Embarked=S 63  31 0 (0.50793651 0.49206349)  
#>            52) Fare< 10.825 37  15 0 (0.59459459 0.40540541) *
#>            53) Fare>=10.825 26  10 1 (0.38461538 0.61538462)  
#>             106) Fare>=17.6 10   3 0 (0.70000000 0.30000000) *
#>             107) Fare< 17.6 16   3 1 (0.18750000 0.81250000) *
#>          27) Embarked=C,Q 54  16 1 (0.29629630 0.70370370) *
#>       7) Pclass< 2.5 170   9 1 (0.05294118 0.94705882) *

# Derive a new specification with different hyperparameters. Because the
# formula and data are inside fit_titanic(), refitting needs only the spec.
spec2 <- update(spec1, min_n = 50, cost_complexity = 0)

fit_titanic(spec2)
#> parsnip model object
#> 
#> Fit time:  10ms 
#> n= 891 
#> 
#> node), split, n, loss, yval, (yprob)
#>       * denotes terminal node
#> 
#>  1) root 891 342 0 (0.61616162 0.38383838)  
#>    2) Sex=male 577 109 0 (0.81109185 0.18890815)  
#>      4) Age>=6.5 553  93 0 (0.83182640 0.16817360) *
#>      5) Age< 6.5 24   8 1 (0.33333333 0.66666667) *
#>    3) Sex=female 314  81 1 (0.25796178 0.74203822)  
#>      6) Pclass>=2.5 144  72 0 (0.50000000 0.50000000)  
#>       12) Fare>=23.35 27   3 0 (0.88888889 0.11111111) *
#>       13) Fare< 23.35 117  48 1 (0.41025641 0.58974359)  
#>         26) Embarked=S 63  31 0 (0.50793651 0.49206349)  
#>           52) Fare< 10.825 37  15 0 (0.59459459 0.40540541) *
#>           53) Fare>=10.825 26  10 1 (0.38461538 0.61538462) *
#>         27) Embarked=C,Q 54  16 1 (0.29629630 0.70370370) *
#>      7) Pclass< 2.5 170   9 1 (0.05294118 0.94705882) *

创建于 2020-04-30，使用 reprex 包 (v0.3.0)

© www.soinside.com 2019 - 2024. All rights reserved.