获取 rpart 节点中的观测值(即:CART)

问题描述 投票:0回答:6

我想检查到达 rpart 决策树中某个节点的所有观察结果。例如,在以下代码中:

fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis)
fit

n= 81 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 81 17 absent (0.79012346 0.20987654)  
   2) Start>=8.5 62  6 absent (0.90322581 0.09677419)  
     4) Start>=14.5 29  0 absent (1.00000000 0.00000000) *
     5) Start< 14.5 33  6 absent (0.81818182 0.18181818)  
      10) Age< 55 12  0 absent (1.00000000 0.00000000) *
      11) Age>=55 21  6 absent (0.71428571 0.28571429)  
        22) Age>=111 14  2 absent (0.85714286 0.14285714) *
        23) Age< 111 7  3 present (0.42857143 0.57142857) *
   3) Start< 8.5 19  8 present (0.42105263 0.57894737) *

我想查看节点 (5) 中的所有观测值(即:Start>=8.5 & Start 的 33 个观测值< 14.5). Obviously I could manually get to them. But I would like to have some function like (say) "get_node_date". For which I could just run get_node_date(5) - and get the relevant observations.

关于如何解决这个问题有什么建议吗?

r decision-tree rpart
6个回答
4
投票

似乎没有这样的函数可以从特定节点提取观测值。我会按如下方式解决它:首先确定您感兴趣的节点使用哪个规则。您可以使用

path.rpart
来解决它。然后,您可以一个接一个地应用规则来提取观察结果。

将此方法作为函数:

get_node_date <- function(tree = fit, node = 5){
  rule <- path.rpart(tree, node)
  rule_2 <- sapply(rule[[1]][-1], function(x) strsplit(x, '(?<=[><=])(?=[^><=])|(?<=[^><=])(?=[><=])', perl = TRUE))
  ind <- apply(do.call(cbind, lapply(rule_2, function(x) eval(call(x[2], kyphosis[,x[1]], as.numeric(x[3]))))), 1, all)
  kyphosis[ind,]
  }

对于节点 5,您得到:

get_node_date()

 node number: 5 
   root
   Start>=8.5
   Start< 14.5
   Kyphosis Age Number Start
2    absent 158      3    14
10  present  59      6    12
11  present  82      5    14
14   absent   1      4    12
18   absent 175      5    13
20   absent  27      4     9
23  present  96      3    12
26   absent   9      5    13
28   absent 100      3    14
32   absent 125      2    11
33   absent 130      5    13
35   absent 140      5    11
37   absent   1      3     9
39   absent  20      6     9
40  present  91      5    12
42   absent  35      3    13
46  present 139      3    10
48   absent 131      5    13
50   absent 177      2    14
51   absent  68      5    10
57   absent   2      3    13
59   absent  51      7     9
60   absent 102      3    13
66   absent  17      4    10
68   absent 159      4    13
69   absent  18      4    11
71   absent 158      5    14
72   absent 127      4    12
74   absent 206      4    10
77  present 157      3    13
78   absent  26      7    13
79   absent 120      2    13
81   absent  36      4    13

1
投票

partykit
包还为此提供了一个罐装解决方案。您只需将
rpart
对象转换为
party
类,即可使用其统一接口来处理树。然后就可以使用
data_party()
功能了。

使用问题中的

fit
并加载
library("partykit")
,您可以首先将
rpart
树强制为
party

pfit <- as.party(fit)
plot(pfit)

按照你想要的方式提取数据只有两个小麻烦:(1)原始拟合中的

model.frame()
总是在强制中丢失,需要手动重新附加。 (2) 节点使用不同的编号方案。您现在需要节点 4(而不是 5)。

pfit$data <- model.frame(fit)
data4 <- data_party(pfit, 4)
dim(data4)
## [1] 33  5
head(data4)
##    Kyphosis Age Start (fitted) (response)
## 2    absent 158    14        7     absent
## 10  present  59    12        8    present
## 11  present  82    14        8    present
## 14   absent   1    12        5     absent
## 18   absent 175    13        7     absent
## 20   absent  27     9        5     absent

另一种方法是从节点 4 开始对子树进行子集化,然后从中获取数据:

pfit4 <- pfit[4]
plot(pfit4)

然后

data_party(pfit4)
给你与上面的
data4
相同的结果。并且
pfit4$data
为您提供没有
(fitted)
节点和预测
(response)
的数据。


1
投票

还有另一种方式,通过查找任何特定节点的所有终端节点并返回调用中使用的数据子集来实现。

fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis)

head(subset.rpart(fit, 5))
#    Kyphosis Age Number Start
# 2    absent 158      3    14
# 10  present  59      6    12
# 11  present  82      5    14
# 14   absent   1      4    12
# 18   absent 175      5    13
# 20   absent  27      4     9


subset.rpart <- function(tree, node = 1L) {
  data <- eval(tree$call$data, parent.frame(1L))
  wh <- sapply(as.integer(rownames(tree$frame)), parent)
  wh <- unique(unlist(wh[sapply(wh, function(x) node %in% x)]))
  data[rownames(tree$frame)[tree$where] %in% wh[wh >= node], ]
}

parent <- function(x) {
  if (x[1] != 1)
    c(Recall(if (x %% 2 == 0L) x / 2 else (x - 1) / 2), x) else x
}

1
投票

rpart 中训练观测的终端节点分配可以从

$where
获得:

fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis)
fit$where

作为函数:

get_node <- function(rpart.object=fit, data=kyphosis, node.number=5) {
  data[which(fit$where == node.number),]  
}
get_node()

这仅适用于训练观察,不适用于新观察。并且不适用于内部节点。


0
投票

rpart 返回 rpart.object 元素,其中包含您需要的信息:

require(rpart)
fit2 <- rpart(Kyphosis ~ Age + Start, data = kyphosis)
fit2

get_node_date <-function(nodeId,fit)
{  
  fit$frame[toString(nodeId),"n"]
}


for (i in c(1,2,4,5,10,11,22,23,3) )
  cat(get_node_date(i,fit2),"\n")

0
投票

另一种方法是从给定节点

n
查找所有子节点。 我们可以使用
rpart
对象来查找它们。将此信息与 数据集中每个点的结束节点(在这个问题中是脊柱后凸), 如上所述,从
fit$where
获得 @rawar,您可以获取给定节点涉及的数据集中的所有点,而不是 必然是结局。

步骤摘要为:

  1. 查找节点号并识别那些末端节点(“叶子”)。这 信息可以在 rpart 对象的
    frame
    元素中找到。
  2. 计算给定节点
    n
    的所有子节点。它们可以被计算 递归地使用节点
    n
    的子节点是
    2*n
    2*n+1
    ,如
    rpart.plot
    包中所述 小插图第26页
  3. 一旦知道从节点
    n
    垂下的叶子,就可以选取节点中的点 这些叶子中的数据集使用 rpart 对象的
    where
    元素

我在函数

get_children_nodes()
中编写了步骤 1 和 2,在函数中编写了步骤 3
get_node_data()
回答了所提出的问题。在这个函数中,我已经 包括打印相应节点规则 (
rule = TRUE
) 的可能性 得到与@datamineR相同的答案

library(rpart)
library(rpart.plot)

fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis)
get_children_nodes <- function(tree, node){
  # check if node is a leaf based in rpart object (tree) information (step 1)
  z <- tree$frame
  is_leaf <- z$var == "<leaf>"
  nodes <- as.integer(row.names(z))
  
  # find recursively all children nodes (step 2)
  find_children <- function(node, nodes, is_leaf){
    condition <- is_leaf[nodes == node]
    if (condition) {
      # If node is leaf, return it
      v1 <- node
    } else {
      # If node is not leaf, search children leaf recursively
      v1 <- c(find_children(2 * node, nodes, is_leaf), 
              find_children(2 * node + 1, nodes, is_leaf))
    } 
    return(v1)
  }
  
  return(find_children(node, nodes, is_leaf))
}
get_node_data <- function(dataset, tree, node, rule = FALSE) {
  # Find children nodes of the node
  children_nodes <- get_children_nodes(tree, node)
  # match those nodes into the rpart node identification
  id_nodes <- which(as.integer(row.names(tree$frame)) %in% children_nodes)
  # Get the elements in the datset involved in the node (step 3)
  filtered_dataset <- dataset[tree$where %in% id_nodes, ]
  
  # print the node rule if needed
  if(rule) {
    rpart::path.rpart(tree, node, pretty = TRUE)
    cat("  \n")
  }
  return( filtered_dataset)
}
# Get the children nodes
get_children_nodes(fit, 5)
#> [1] 10 22 23
# Complete function to return the elements of node 5
get_node_data(kyphosis, fit, 5, rule = TRUE) 
#> 
#>  node number: 5 
#>    root
#>    Start>=8.5
#>    Start< 14.5
#> 
#>    Kyphosis Age Number Start
#> 2    absent 158      3    14
#> 10  present  59      6    12
#> 11  present  82      5    14
#> 14   absent   1      4    12
#> 18   absent 175      5    13
#> 20   absent  27      4     9
#> 23  present  96      3    12
#> 26   absent   9      5    13
#> 28   absent 100      3    14
#> 32   absent 125      2    11
#> 33   absent 130      5    13
#> 35   absent 140      5    11
#> 37   absent   1      3     9
#> 39   absent  20      6     9
#> 40  present  91      5    12
#> 42   absent  35      3    13
#> 46  present 139      3    10
#> 48   absent 131      5    13
#> 50   absent 177      2    14
#> 51   absent  68      5    10
#> 57   absent   2      3    13
#> 59   absent  51      7     9
#> 60   absent 102      3    13
#> 66   absent  17      4    10
#> 68   absent 159      4    13
#> 69   absent  18      4    11
#> 71   absent 158      5    14
#> 72   absent 127      4    12
#> 74   absent 206      4    10
#> 77  present 157      3    13
#> 78   absent  26      7    13
#> 79   absent 120      2    13
#> 81   absent  36      4    13

创建于 2023-08-14,使用 reprex v2.0.2

© www.soinside.com 2019 - 2024. All rights reserved.