我正在尝试让 R 函数(更加)高效。在下面找到一个工作示例。
smsurv <- function(Time,Status,X,beta,w,model){
death_point <- sort(unique(subset(Time, Status==1)))
if(model=='ph') coxexp <- exp((beta)%*%t(X[,-1]))
n <- length(death_point)
lambda <- numeric(n)
for(i in 1: n){
if(model=='ph') temp <- sum(as.numeric(Time>=death_point[i])*w*drop(coxexp))
if(model=='aft') temp <- sum(as.numeric(Time>=death_point[i])*w)
lambda[i] <- sum(Status*as.numeric(Time==death_point[i]))/temp
}
HHazard <- numeric()
for(i in 1:length(Time)){
HHazard[i] <- sum(as.numeric(Time[i]>=death_point)*lambda)
if(Time[i]>max(death_point))HHazard[i] <- Inf
if(Time[i]<min(death_point))HHazard[i] <- 0
}
survival <- exp(-HHazard)
list(survival=survival)
}
nr_obs = 50000
Time_input <- rnorm(nr_obs, mean = 100, sd = 36)
Status_input <- sample(c(0,1), replace=TRUE, size=nr_obs)
w_input <- Status_input
# Let's suppose there are 9 variables (first column denotes the intercept)
n_variables <- 9
X_input <- matrix(rnorm(nr_obs*n_variables),nr_obs)
X_input <- cbind(Intercept = rep(1, nrow(X_input)), X_input)
beta_input <- runif(n_variables, min = -1, max = 1)
model_input <- "ph"
output <- smsurv(Time_input,Status_input,X_input,beta_input,w_input,model_input)
我已经尝试用 lapply 和 sapply 替换 for 循环,但这实际上使函数变得更慢:
smsurv2 <- function(Time,Status,X,beta,w,model){
death_point <- sort(unique(subset(Time, Status==1)))
if(model=='ph') coxexp <- exp((beta)%*%t(X[,-1]))
if(model_input=='ph') lambda =unlist(lapply(death_point, function(z) sum(Status_input*as.numeric(Time_input==z))/ sum(as.numeric(Time_input>=z)*w_input*drop(coxexp))))
if(model=='aft') lambda =unlist( lapply(death_point, function(z) sum(Status_input*as.numeric(Time_input==z))/ sum(as.numeric(Time_input>=z)*w_input)))
HHazard <- unlist(lapply(Time, function(t) {sum(as.numeric(t>=death_point)*lambda)}))
HHazard[Time > max(death_point)] <- Inf
HHazard[Time < min(death_point)] <- 0
survival <- exp(-HHazard)
list(survival=survival)
}
smsurv3 <- function(Time, Status, X, beta, w, model){
death_point <- sort(unique(subset(Time, Status==1)))
if(model=='ph') coxexp <- exp((beta)%*%t(X[,-1]))
lambda <- sapply(death_point, function(dp) {return(sum(Status*as.numeric(Time==dp))/sum(as.numeric(Time>=dp)*w*drop(coxexp)))})
HHazard <- sapply(Time, function(t){return(sum(as.numeric(t>=death_point)*lambda))})
HHazard[Time > max(death_point)] <- Inf
HHazard[Time < min(death_point)] <- 0
survival <- exp(-HHazard)
list(survival=survival)
}
考虑到这一点,有人还有其他我可以尝试的建议吗?我最近阅读了有关 rcpp 包的信息,但我不确定如何用 C 代码替换 for 循环。非常欢迎任何建议。
函数
smsurv2
速度更快,并且结果与问题的函数相同。
smsurv2 <- function(Time, Status, X, beta, w, model){
death_point <- Time[Status == 1] |> unique() |> sort()
n <- length(death_point)
lambda <- numeric(n)
if(model == 'ph') {
coxexp <- (exp(beta %*% t(X[, -1]))) |> drop()
} else if(model == 'aft') {
coxexp <- rep(1, length(Time))
}
for(i in seq_along(death_point)){
temp <- sum((Time >= death_point[i]) * w * coxexp)
lambda[i] <- sum(Status * (Time == death_point[i])) / temp
}
HHazard <- numeric(length(Time))
for(i in seq_along(Time)){
HHazard[i] <- sum((Time[i] >= death_point) * lambda)
}
HHazard[ Time > max(death_point) ] <- Inf
HHazard[ Time < min(death_point) ] <- 0
survival <- exp(-HHazard)
list(survival = survival)
}
nr_obs = 50000
Time_input <- rnorm(nr_obs, mean = 100, sd = 36)
Status_input <- sample(c(0,1), replace=TRUE, size=nr_obs)
w_input <- Status_input
# Let's suppose there are 9 variables (first column denotes the intercept)
n_variables <- 9
X_input <- matrix(rnorm(nr_obs*n_variables),nr_obs)
X_input <- cbind(Intercept = rep(1, nrow(X_input)), X_input)
beta_input <- runif(n_variables, min = -1, max = 1)
model_input <- "ph"
system.time(
output <- smsurv(Time_input,Status_input,X_input,beta_input,w_input,model_input)
)
#> user system elapsed
#> 25.08 4.95 33.39
system.time(
output2 <- smsurv2(Time_input,Status_input,X_input,beta_input,w_input,model_input)
)
#> user system elapsed
#> 16.92 1.45 19.86
identical(output, output2)
#> [1] TRUE
创建于 2023-09-30,使用 reprex v2.0.2