Update the estimates in a party / partykit model with the means of unseen holdout data


I want to build a decision tree on a subsample of the data (using evtree, which takes a very long time to run on large datasets).

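I then want to take this model and update the terminal node estimates using the holdout data. This is similar to the concept of "honesty" in the GRF package, where bias from building the model on a sample is offset by looking at held-out data. The end result would be a final model that is generally less biased, faster to fit (smaller training input) and lower in variance. Ideally, I would then be able to use the new model to predict on new data.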

library(partykit)
library(dplyr)  # for %>%, group_by() and summarize()
set.seed(12)
train = sample(nrow(mtcars), nrow(mtcars)/1.5)
sample_tree = ctree(mpg ~ ., data = mtcars[train, ])
sample_tree %>% as.simpleparty

# Fitted party:
# [1] root
# |   [2] cyl <= 6: 23.755 (n = 11, err = 224.8)
# |   [3] cyl > 6: 15.380 (n = 10, err = 42.1)

# terminal node of each holdout row, paired with its observed mpg
data.frame(node = predict(sample_tree, newdata = mtcars[-train, ], type = 'node'),
           observed = mtcars[-train, ]$mpg) %>%
  group_by(node) %>%
  summarize(mpg = mean(observed)) %>% as.list

 # $node
 # [1] 2 3
 # $mpg
 # [1] 24.31429 14.40000

In this example, I would update node IDs 2 and 3 in the tree to 24.31429 and 14.40000, respectively.
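If rebuilding the party object itself is not strictly required, the same "honest" update can be sketched as a plain lookup table: the tree is kept only to route observations to terminal nodes, and the holdout mean is stored per node ID. The helper names honest_node_means() and honest_predict() below are hypothetical (they are not part of partykit); in practice the node IDs would come from predict(sample_tree, newdata, type = "node") as in the code above.

```r
# Hypothetical helpers (not part of partykit): per-terminal-node holdout means
# stored as a lookup vector keyed by node ID.
honest_node_means <- function(node, y, w = rep(1, length(y))) {
  # (weighted) mean of the holdout response within each terminal node
  sums <- tapply(y * w, node, sum)
  wts  <- tapply(w, node, sum)
  sums / wts
}

honest_predict <- function(means, node) {
  # look up the terminal node ID of each new row; unseen node IDs give NA
  unname(means[as.character(node)])
}

# Toy node IDs and responses, standing in for predict(..., type = "node")
holdout_node <- c(2, 2, 3, 3, 3)
holdout_y    <- c(24, 26, 15, 14, 16)
means <- honest_node_means(holdout_node, holdout_y)  # node 2 -> 25, node 3 -> 15
honest_predict(means, c(3, 2, 2))                    # 15 25 25
```

With the question's objects this would be honest_node_means(predict(sample_tree, newdata = mtcars[-train, ], type = "node"), mtcars[-train, "mpg"]). The trade-off is that the result is no longer a party object, so print() and plot() still show the training-sample estimates.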

What I have tried: ChatGPT 1000x, lots of Googling, jumping through various hoops to figure out how to get at the terminal node values, etc.


Edit 2: This seems to work, but I don't 100% understand why. Proceed with caution.

Adapted from Achim Zeileis's answer:

library(partykit)
library(evtree)
library(ggplot2)  # for the diamonds data
library(dplyr)

set.seed(123)
train = sample(nrow(diamonds), nrow(diamonds)/20)
diamonds_evtree = evtree(price ~ ., data = (diamonds %>% select(any_of(c("carat", "depth", "table", "price"))))[train, ],
                         maxdepth = 3L, niterations = 101)

diamonds_ctree = ctree(price ~ ., data = (diamonds %>% select(any_of(c("depth", "table", "price", "x", "y"))))[train, ])

# refit_constparty() is the function from the answer below
refit_constparty(as.constparty(diamonds_evtree), diamonds[-train,]) # fails
refit_constparty(diamonds_ctree, diamonds[-train,]) # works

as.constparty(diamonds_evtree)


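refit_simpleparty <- function(object, newdata) {
  stopifnot(inherits(object, c("constparty", "simpleparty")))
  # only unit weights are supported
  if(any(abs(object$fitted[["(weights)"]] - 1) > 0)) {
    stop("weights not implemented yet")
  }
  # response in column 1, predictors transformed as in the original fit;
  # splits look up variables by column position (varid), so d must have the
  # same column layout as object$data
  d <- model.frame(terms(object), data = newdata)
  # keep the tree structure, attach the new data and route it through the nodes
  ret <- party(object$node,
               data = d,
               fitted = data.frame(
                 "(fitted)" = fitted_node(object$node, d),
                 "(response)" = d[[1L]],
                 "(weights)" = 1L,
                 check.names = FALSE),
               terms = terms(object))
  # recompute the node summaries from the new data
  as.simpleparty(ret)
}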

# works with "arbitrary data"
refit_simpleparty(diamonds_ctree %>% as.simpleparty, newdata = diamonds)
1 Answer

This can be done by setting up a new party() with the new data and fitted values and subsequently coercing it to constparty. See vignette("constparty", package = "partykit") for more details and worked examples.
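To make the mechanism concrete, here is a minimal sketch in the spirit of that vignette: a one-split tree is built by hand with partynode() and partysplit(), wrapped with party() together with the fitted node IDs and the response, and coerced with as.constparty(). The split on wt at 3 is an arbitrary choice for illustration, and the sketch assumes partykit is installed.

```r
library(partykit)

d <- mtcars[, c("mpg", "wt")]

# One split on column 2 of d (wt): kid 1 if wt <= 3, kid 2 otherwise
sp <- partysplit(varid = 2L, breaks = 3)
nd <- partynode(id = 1L, split = sp,
                kids = list(partynode(id = 2L), partynode(id = 3L)))

py <- party(nd, data = d,
            fitted = data.frame(
              "(fitted)"   = fitted_node(nd, data = d),  # terminal node IDs
              "(response)" = d$mpg,
              check.names = FALSE),
            terms = terms(mpg ~ wt))
cp <- as.constparty(py)

# Each terminal node predicts the mean mpg of the rows stored in its fitted data
pred <- predict(cp, newdata = d)
```

Swapping d for holdout data in the party() call is exactly what replaces the training-sample constants with holdout means.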

I have written a short function that wraps up the necessary steps:

library(partykit)

refit_constparty <- function(object, newdata) {
  stopifnot(inherits(object, "constparty"))
  if(any(abs(object$fitted[["(weights)"]] - 1) > 0)) {
    stop("weights not implemented yet")
  }
  # evaluate the model formula on the new data (transformations, factors, ...)
  d <- model.frame(terms(object), data = newdata)
  y <- names(d)[1L]
  # reorder columns to match the original data, since splits refer to
  # variables by column position
  d <- d[, names(object$data), drop = FALSE]
  ret <- party(object$node,
    data = d,
    fitted = data.frame(
      "(fitted)" = fitted_node(object$node, d),  # terminal node of each new row
      "(response)" = d[[y]],
      "(weights)" = 1L,
      check.names = FALSE),
    terms = terms(object))
  as.constparty(ret)
}

Note that the call to model.frame() is important because it may reorder and transform variables (e.g., setting up factors or logs on the fly).
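A small base-R illustration of this point (the formula is an arbitrary example on mtcars, not the one from the question): reusing the terms of a fitted formula on new data whose columns are in a different order still yields the response in the first column, followed by the predictors in formula order, with log() and factor() already evaluated.

```r
fit_terms <- terms(log(mpg) ~ wt + factor(cyl))

# new data whose columns come in a different order than in the formula
newdata <- mtcars[1:5, c("cyl", "wt", "mpg")]
mf <- model.frame(fit_terms, data = newdata)

names(mf)                     # "log(mpg)", "wt", "factor(cyl)"
is.factor(mf[["factor(cyl)"]])  # TRUE: the factor was created on the fly
```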

For your data split, I obtain the following:

refit_constparty(sample_tree, mtcars[-train,])
## Model formula:
## mpg ~ cyl + disp + hp + drat + wt + qsec + vs + am + gear + carb
## 
## Fitted party:
## [1] root
## |   [2] wt <= 2.32: NA (n = 0, err = NA)
## |   [3] wt > 2.32: 17.664 (n = 11, err = 135.8)
## 
## Number of inner nodes:    1
## Number of terminal nodes: 2

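In node 2, the fitted value is NA because it contains no observations. (Maybe I did something wrong, but I could not replicate the fitted values shown above.)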
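If an empty terminal node should not predict NA, one option (not proposed in the answer) is to fall back to the training-sample estimate for any node that receives no holdout observations. A base-R sketch with a hypothetical helper, fill_empty_nodes(), where both inputs are numeric vectors named by node ID and the numbers are toy values:

```r
# Hypothetical helper: use holdout means where available and the
# training-sample means for nodes without any holdout observations.
fill_empty_nodes <- function(holdout_means, train_means) {
  out <- holdout_means[names(train_means)]  # align on node ID (missing -> NA)
  names(out) <- names(train_means)
  missing <- is.na(out)
  out[missing] <- train_means[missing]
  out
}

train_means   <- c("2" = 20, "3" = 10)  # toy values
holdout_means <- c("3" = 12)            # node 2 received no holdout rows
fill_empty_nodes(holdout_means, train_means)  # node 2 -> 20, node 3 -> 12
```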
