我想在数据子样本上创建一个决策树(使用 evtree,它在大型数据集上运行时间非常长)。
然后我想采用这个模型并使用保留数据的估计更新终端节点估计。这类似于 GRF 包中的“诚实”概念,其中通过查看保留数据来抵消采样模型构建中的偏差。这种情况下的最终结果将是最终模型,通常偏差较小,运行速度更快(训练输入较小)并且方差较低。 理想情况下,我能够采用新模型并推断新数据。
library(partykit)
mtcars
set.seed(12)
train = sample(nrow(mtcars), nrow(mtcars)/1.5)
sample_tree = ctree(mpg ~. , data = mtcars[train, ])
sample_tree %>% as.simpleparty
# Fitted party:
# [1] root
# | [2] cyl <= 6: 23.755 (n = 11, err = 224.8)
# | [3] cyl > 6: 15.380 (n = 10, err = # 42.1)
data.frame(node = predict(sample_tree, newdata = mtcars[-train, ], type = 'node'),
prediction = mtcars[-train, ]$mpg) %>%
group_by(node) %>%
summarize(mpg = mean(prediction)) %>% as.list
# $node
# [1] 2 3
# $mpg
# [1] 24.31429 14.40000
在本例中,我将树中的节点 ID 2,3 分别更新为 24.31429 和 14.40000。
我尝试过的事情: 聊天 GPT 1000x,大量谷歌搜索,跳过各种步骤来弄清楚如何获取终端节点值,等等。
edit2:这似乎有效,但我并不 100% 理解为什么。谨慎行事
改编自 Achim Zeileis 的回答
# library(evtree)
set.seed(123)
train = sample(nrow(diamonds), nrow(diamonds)/20)
diamonds_evtree = evtree("price ~ .", data = (diamonds %>% select(any_of(c("carat", "depth", "table", "price"))))[train, ],
maxdepth = 3L, niterations = 101)
diamonds_ctree = ctree(price ~ ., data = (diamonds %>% select(any_of(c("depth", "table", "price", "x", "y", "y"))))[train, ])
refit_constparty(as.constparty(diamonds_evtree), diamonds[-train,]) #fails
refit_constparty(diamonds_ctree, diamonds[-train,]) #works
as.constparty(diamonds_evtree)
refit_simpleparty <- function(object, newdata) {
stopifnot(inherits(object, "constparty") | inherits(object, "simpleparty"))
if(any(abs(object$fitted[["(weights)"]] - 1) > 0)) {
stop("weights not implemented yet")
}
d <- model.frame(terms(object), data = newdata)
ret <- party(object$node,
data = d,
fitted = data.frame(
"(fitted)" = fitted_node(object$node, d),
"(response)" = d[[1L]],
"(weights)" = 1L,
check.names = FALSE),
terms = terms(object))
as.simpleparty(ret)
}
# works with "arbitrary data"
refit_simpleparty(diamonds_ctree %>% as.simpleparty, newdata = diamonds)
这可以通过使用新数据和拟合值设置新的
party()
并随后强制到 constparty
来完成。请参阅 vignette("constparty", package = "partykit")
了解更多详细信息和工作示例。
我编写了一个简短的函数,其中封装了必要的步骤:
refit_constparty <- function(object, newdata) {
stopifnot(inherits(object, "constparty"))
if(any(abs(object$fitted[["(weights)"]] - 1) > 0)) {
stop("weights not implemented yet")
}
d <- model.frame(terms(object), data = newdata)
y <- names(d)[1L]
d <- d[, names(object$data), drop = FALSE]
ret <- party(object$node,
data = d,
fitted = data.frame(
"(fitted)" = fitted_node(object$node, d),
"(response)" = d[[y]],
"(weights)" = 1L,
check.names = FALSE),
terms = terms(object))
as.constparty(ret)
}
请注意,调用
model.frame()
对于可能重新排序和转换变量(例如,动态设置因子或日志)非常重要。
对于您的数据分割,我获得以下信息:
refit_constparty(sample_tree, mtcars[-train,])
## Model formula:
## mpg ~ cyl + disp + hp + drat + wt + qsec + vs + am + gear + carb
##
## Fitted party:
## [1] root
## | [2] wt <= 2.32: NA (n = 0, err = NA)
## | [3] wt > 2.32: 17.664 (n = 11, err = 135.8)
##
## Number of inner nodes: 1
## Number of terminal nodes: 2
在节点 2 中,拟合值为 NA,因为没有观测值。 (也许我做错了什么,但我无法复制上面显示的拟合值。)