我想使用多个估算数据的多项回归来获得分类变量每个级别的汇总平均预测。我的问题是
pool()
函数似乎会折叠所有级别以产生一个估计的平均预测,而不是保留每个类别的分层。这是一个分步示例:
rent
的 library(catdata)
数据,我创建了 2 个分类变量library("tidyverse")
library("catdata")
data(rent)
rent<-rent%>%
mutate(rent_cat=factor(case_when(
rent<389.95~"Low price",
rent>=389.95&rent<700.48~"Medium price",
rent>=700.48~"High price"
)))%>%
mutate(size_cat=factor(case_when(
size<53~"Small",
size>=53&size<83~"Medium",
size>=83~"Big"
)))
micerent<-mice(rent,m=3)
需要注意的是,基于这篇文章,最好构建一个函数来提取预测概率并将其以完整的格式应用于我们的估算数据集。我正在使用
avg_predictions
以下来自 marignaleffects
包教程的建议。
fit_reg <- function(dat) {
mod <- multinom(rent_cat~ size_cat,data=dat)
out <- avg_predictions(mod,type="probs",by="size_cat")
return(out)
}
micecompleterent<- complete(micerent, "all")
model<-lapply(micecompleterent, fit_reg)
在那里,它似乎不是按
size_cat
级别生成平均预测估计,而是折叠所有分层平均值以给出一个汇总平均预测。
summary(pool(model),conf.int=T)
我还尝试在
newdata=datagrid(size_cat=c("Big","Medium","Small")
包中的 predictions
函数中使用 marginaleffects
,但结果是相同的:
fit_reg_response <- function(dat) {
mod <- multinom(c19a~ sexatbirth + age_cat,subset=(time_num=="0"),data=dat)
out <- predictions(mod,newdata=datagrid(size_cat=c("Big","Medium","Small")))
return(out)
}
有两个问题:
mice::pool
在所有对象上调用 tidy()
函数。它期望 tidy()
的输出是一个带有 term
列的 data.frame,但是在这种情况下 tidy
不会返回这样的列,因为没有一个健壮且通用的方法来说明什么是“术语” ”是在平均预测的背景下。tidy()
的输出时,
avg_predictions()
没有超级简单的方法来解决这个问题,但这里有一个解决方法:
fit_reg()
函数的输出分配自定义类名称。tidy()
S3 方法,根据您关心的行标识符创建您想要的 term
列。library("tidyverse")
library("catdata")
library("mice")
library("nnet")
library("marginaleffects")
data(rent)
rent<-rent%>%
mutate(rent_cat=factor(case_when(
rent<389.95~"Low price",
rent>=389.95&rent<700.48~"Medium price",
rent>=700.48~"High price"
)))%>%
mutate(size_cat=factor(case_when(
size<53~"Small",
size>=53&size<83~"Medium",
size>=83~"Big"
)))
micerent <- mice(rent, m = 3, seed = 1024)
fit_reg <- function(dat) {
mod <- multinom(rent_cat ~ size_cat, data = dat, trace = FALSE)
out <- avg_predictions(mod, type = "probs", by = "size_cat")
# the next line is key
class(out) <- c("custom_class", class(out))
return(out)
}
micecompleterent <- complete(micerent, "all")
model <- lapply(micecompleterent, fit_reg)
tidy.custom_class <- function(x, ...) {
# create a row identifier
transform(x, term = paste0(group, size_cat))
}
summary(pool(model), conf.int = TRUE)
# term estimate std.error statistic df
# 1 High priceBig 6.131852e-01 0.0211548249 28.98559586 2041.778
# 2 High priceMedium 1.867514e-01 0.0122504868 15.24440754 2041.778
# 3 High priceSmall 2.399958e-06 0.0000685316 0.03501973 2041.778
# 4 Low priceBig 5.283582e-02 0.0097171507 5.43737813 2041.778
# 5 Low priceMedium 1.650216e-01 0.0116685744 14.14239902 2041.778
# 6 Low priceSmall 6.223208e-01 0.0214465891 29.01723885 2041.778
# 7 Medium priceBig 3.339790e-01 0.0204863976 16.30247432 2041.778
# 8 Medium priceMedium 6.482270e-01 0.0150108250 43.18396572 2041.778
# 9 Medium priceSmall 3.776768e-01 0.0214465633 17.61013157 2041.778
# p.value 2.5 % 97.5 %
# 1 5.135684e-155 0.5716979181 0.6546724948
# 2 8.872206e-50 0.1627266590 0.2107761683
# 3 9.720674e-01 -0.0001319992 0.0001367991
# 4 6.054466e-08 0.0337792608 0.0718923850
# 5 2.168686e-43 0.1421380844 0.1879051861
# 6 2.679125e-155 0.5802613238 0.6643802735
# 7 2.898326e-56 0.2938025529 0.3741553884
# 8 5.330330e-290 0.6187888240 0.6776650781
# 9 9.846943e-65 0.3356173772 0.4197362256