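I am building a function that adds a validation column to data I plan to model later. (tidymodel::initial.split() won't let me stratify deeply enough, hence my DIY approach.)
The function is simple: used inside mutate(), it adds a column that reads "training" or "validation", assigned at random in a 0.75/0.25 ratio.
I have the function working, but it is not perfect. The only way I can get it to see the group_by() is to put an n() in my calling code.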
# example data
library(dplyr)  # provides tibble() and %>%
df <- tibble(
  x = seq(1, 15, by = 1) %>% rep(100),
  true = (seq(1, 100, by = 1) * 100 * runif(1)) %>% rep(each = 15),
  measurment = lapply(rep(1, 100),
                      function(data) {
                        data * rnorm(15, mean = 100 * runif(1), sd = 2)
                      }) %>% unlist(),
  z = c("a", "b", "c", "d", "e") %>% rep(20, each = 15),
  ID = paste0("a", 1:20) %>% rep(each = 15 * 5)
)
validation_column <- function(data) {
  # Generate a vector of labels based on the desired ratio
  labels <- c(rep("training", round(0.75 * data)), rep("validation", round(0.25 * data)))
  # Shuffle the labels
  labels <- sample(labels)
  return(labels)
}
df %>%
  group_by(ID, z) %>%
  mutate(validation = validation_column(n()))
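I believe I need to work some kind of if {} into my code to trigger do() when building the function, but I cannot figure out exactly how. Here is my attempt, based on https://www.r-bloggers.com/2018/07/writing-pipe-friend-functions/ but it still doesn't work.
Any suggestions are welcome, so that I can build better functions in the future.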
validation_column <- function(data) {
  if (dplyr::is_grouped_df(data)) {
    return(dplyr::do(data, ~ validation_column(.)))
  }
  # Generate a vector of labels based on the desired ratio
  labels <- c(rep("training", round(0.75 * n(data))), rep("validation", round(0.25 * n(data))))
  # Shuffle the labels
  labels <- sample(labels)
  return(labels)
}
df %>%
  group_by(ID, z) %>%
  mutate(validation = validation_column())
Your rep(..round(data)) needs a number, a length, as its input. Try this:
validation_column <- function(data) {
  # Generate a vector of labels based on the desired ratio
  labels <- c(
    rep("training", round(0.75 * length(data))),    # length(data)
    rep("validation", round(0.25 * length(data))))  # length(data)
  # Shuffle the labels
  labels <- sample(labels)
  return(labels)
}
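As a side note, if the ratio is ever changed, two independent round() calls may not add up to the number of rows (for example, a 0.5/0.5 split of 3 rows gives round(1.5) + round(1.5) = 2 + 2 = 4), and mutate() would then complain about the column size. A minimal sketch of one way around that, assuming a hypothetical helper name split_labels and a hypothetical prop_train argument (neither is in the original code), is to round only the training count and give the remainder to validation:

```r
# Hypothetical variant of validation_column(); the names `split_labels`
# and `prop_train` are assumptions made for this sketch.
split_labels <- function(data, prop_train = 0.75) {
  n <- length(data)                 # rows in the column (or in the current group)
  n_train <- round(prop_train * n)  # size of the training share
  # Validation gets the remainder, so the result always has length n
  labels <- rep(c("training", "validation"), times = c(n_train, n - n_train))
  # Shuffle the labels
  sample(labels)
}

set.seed(1)
labels <- split_labels(1:15)
table(labels)
```

It takes a column in the same way as validation_column() above.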
Pass a column as the mutate input:
> mutate(df, validation = validation_column(z))
# A tibble: 1,500 × 6
       x  true measurment z     ID    validation
   <dbl> <dbl>      <dbl> <chr> <chr> <chr>
 1     1  30.8       26.0 a     a1    training
 2     2  30.8       25.6 a     a1    training
 3     3  30.8       27.5 a     a1    training
 4     4  30.8       26.0 a     a1    training
 5     5  30.8       26.4 a     a1    validation
 6     6  30.8       24.6 a     a1    validation
 7     7  30.8       27.2 a     a1    validation
 8     8  30.8       24.1 a     a1    training
 9     9  30.8       25.0 a     a1    training
10    10  30.8       25.9 a     a1    training
# ℹ 1,490 more rows
# ℹ Use `print(n = ...)` to see more rows
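The call above is ungrouped, so the 75/25 split is taken over all 1,500 rows. Because mutate() evaluates its expressions once per group on a grouped data frame, the same function should also split within each group after group_by(), with no n() in the calling code and no is_grouped_df()/do() branch: inside mutate() the function receives a column vector, never the grouped data frame itself. The sketch below checks this on a small stand-in data set (toy is made up for illustration and is not the df from the question), and add_validation() is a hypothetical pipe-friendly wrapper, not something from the original post:

```r
library(dplyr)

# The function from this answer, repeated so the sketch runs on its own
validation_column <- function(data) {
  labels <- c(
    rep("training", round(0.75 * length(data))),
    rep("validation", round(0.25 * length(data))))
  sample(labels)
}

# Made-up stand-in data: 2 IDs x 2 z levels, 8 rows per group
toy <- tibble(
  ID = rep(c("a1", "a2"), each = 16),
  z  = rep(rep(c("a", "b"), each = 8), times = 2),
  x  = rep(1:8, times = 4)
)

# Grouped call: x is sliced per group, so length(data) is the group size
split_counts <- toy %>%
  group_by(ID, z) %>%
  mutate(validation = validation_column(x)) %>%
  ungroup() %>%
  count(ID, z, validation)

# Hypothetical wrapper: data frame in, data frame out, grouping respected
add_validation <- function(data) {
  mutate(data, validation = validation_column(seq_len(n())))
}

wrapped_counts <- toy %>%
  group_by(ID, z) %>%
  add_validation() %>%
  ungroup() %>%
  count(ID, z, validation)

split_counts
```

With 8 rows per group, every (ID, z) combination should end up with 6 "training" and 2 "validation" rows. The wrapper hard-codes the column name validation; making it an argument is left out to keep the sketch short.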