使用ggparty的决策树图的边缘上的小数位数

问题描述 投票:1回答:1

我想使用功能强大的ggparty软件包绘制决策树(由partykit软件包估算)。一切都很好,除了数字拆分变量的小数位数。我如何在breaks_label中格式化geom_edge_label(),例如在下图中将> 75.33333更改为> 75.3round()不起作用。我可能会通过一般的options(digits = 3)使用解决方法,但我想知道是否还有更直接的方法。

library("ggparty") 
data("WeatherPlay", package = "partykit")

sp_o <- partysplit(1L, index = 1:3)
sp_h <- partysplit(3L, breaks = 75 + 1/3)
sp_w <- partysplit(4L, index = 1:2)
pn <- partynode(1L, split = sp_o, kids = list(
    partynode(2L, split = sp_h, kids = list(
        partynode(3L, info = "yes"),
        partynode(4L, info = "no"))),
    partynode(5L, info = "yes"),
    partynode(6L, split = sp_w, kids = list(
        partynode(7L, info = "yes"),
        partynode(8L, info = "no")))))
py <- party(pn, WeatherPlay)

ggparty(py) +
    geom_edge() +
    # geom_edge_label() +
    geom_edge_label(mapping = aes(label = paste(breaks_label))) +
    geom_node_splitvar() +
    geom_node_info()

<< img src =“ https://image.soinside.com/eyJ1cmwiOiAiaHR0cHM6Ly9pLmltZ3VyLmNvbS90Unp2VlFvLnBuZyJ9” alt =“”>

reprex package(v0.3.0)在2020-03-05创建

r ggplot2 decision-tree
1个回答
0
投票

感谢您使用ggparty!

因此,我认为,当前版本的确没有直接解决方案。但是我一定会在将来实现!

通常,通过仅在节点的子集上使用几何,通常可以解决很多问题。正如您已经注意到的,breaks_label不是存储为数字,而是存储为字符,并在其前面带有不等号的一些可解析文本。因此,您必须使用诸如substr()之类的东西。

ggparty(py) +
  geom_edge() +
  geom_edge_label(id = -c(3, 4)) +
    geom_edge_label(mapping = aes(label = paste(substr(breaks_label, start = 1, stop = 15))),
                    id = c(3, 4)) +
  geom_node_splitvar() +
  geom_node_info() 

我还修改了内部函数之一以包括舍入功能,因此您可以从github获取并使用它。但是我还没有真正测试过,所以使用后果自负;)

library(devtools)
source_url("https://raw.githubusercontent.com/martin-borkovec/ggparty/martin/R/add_splitvar_breaks_index_new.R")

rounded_labels <- add_splitvar_breaks_index_new(party_object = py,
                                                plot_data = ggparty:::get_plot_data(py), 
                                                round_digits = 2)

ggparty(py) +
  geom_edge() +
  geom_edge_label(mapping = aes(label = unlist(rounded_labels)),
                  data = rounded_labels) +
  geom_node_splitvar() +
  geom_node_info()
© www.soinside.com 2019 - 2024. All rights reserved.