How can I extract the residuals of a random forest regression when using spatial resampling?


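I am trying to replicate this tutorial, but with a random forest (RF) instead of a linear regression. I can make predictions, but now I want to compute and extract the residuals of the regression (i.e. observed minus predicted values). I then want to cbind the coordinates of the data.frame with the residuals, like so: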

resids_df <- cbind(original_df[, 1:2], rf_resids) # where the original_df[, 1:2] contains the coordinates and the rf_resids are the residuals of the RF regression

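The problem is that, after making the predictions, the output contains 54 more values than my original dataset (3792 vs. 3738). Because the row counts differ, I cannot cbind the residuals.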

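How can I solve this and get exactly as many residuals as there are observations in the original dataset?

P.S. My dataset does not contain any NA values.

In the example below I use a subset of my dataset, but you can again see that the predictions contain 1 more value than there are rows in the dataset.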

library(tidymodels)
library(spatialsample)
library(sf)

wd <- "path/"

proj_ref_sys <- "EPSG:7760"

drought <- read.csv(paste0(wd, "block.data.csv"))
nrow(drought)
# [1] 60 !!!!!!!!!!!!!

drought_sf <- st_as_sf(drought, coords = c("x", "y"),  crs = proj_ref_sys)

set.seed(123)
folds <- spatial_block_cv(drought_sf, v = 3)

drought_res <-
  workflow(ntl ~ pop + agbh + nir, 
           rand_forest(mode = "regression", mtry = 2, trees = 100) %>%
             set_engine("randomForest")) %>%
  fit_resamples(folds, control = control_resamples(save_pred = TRUE))

drought_res

collect_predictions(drought_res)

# A tibble: 61 × 5 !!!!!!!!!!!!!!!!!!!
   id    .pred  .row   ntl .config             
   <chr> <dbl> <int> <dbl> <chr>               
 1 Fold1 28.7     18  29.2 Preprocessor1_Model1
 2 Fold1 27.9     19  32.8 Preprocessor1_Model1
 3 Fold1 17.2     20  29.6 Preprocessor1_Model1
 4 Fold1 19.6     21  28.6 Preprocessor1_Model1
 5 Fold1 34.3     22  36.5 Preprocessor1_Model1
 6 Fold1 48.7     28  34.8 Preprocessor1_Model1
 7 Fold1 45.9     29  32.2 Preprocessor1_Model1
 8 Fold1 40.1     30  28.3 Preprocessor1_Model1
 9 Fold1 14.6     31  22.5 Preprocessor1_Model1
10 Fold1  9.96    32  17.1 Preprocessor1_Model1
# ℹ 51 more rows
# ℹ Use `print(n = ...)` to see more rows

The data.frame I am using:

structure(list(x = c(995494.2549, 995924.2549, 996354.2549, 996784.2549, 
997214.2549, 997644.2549, 998074.2549, 998504.2549, 998934.2549, 
999364.2549, 999794.2549, 1000224.2549, 1000654.2549, 1001084.2549, 
1001514.2549, 1001944.2549, 1002374.2549, 1002804.2549, 1003234.2549, 
1003664.2549, 1004094.2549, 1004524.2549, 1004954.2549, 1005384.2549, 
1005814.2549, 1006244.2549, 1006674.2549, 1007104.2549, 1007534.2549, 
1007964.2549, 1008394.2549, 1008824.2549, 1009254.2549, 1009684.2549, 
1010114.2549, 1010544.2549, 1010974.2549, 1011404.2549, 1011834.2549, 
1012264.2549, 1012694.2549, 1013124.2549, 1013554.2549, 1013984.2549, 
1014414.2549, 1014844.2549, 1015274.2549, 1015704.2549, 1016134.2549, 
1016564.2549, 1016994.2549, 1017424.2549, 1017854.2549, 1018284.2549, 
1018714.2549, 995494.2549, 995924.2549, 996354.2549, 996784.2549, 
997214.2549), y = c(1019851.5842, 1019851.5842, 1019851.5842, 
1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 
1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 
1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 
1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 
1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 
1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 
1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 
1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 
1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 
1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 
1019851.5842, 1019851.5842, 1019421.5842, 1019421.5842, 1019421.5842, 
1019421.5842, 1019421.5842), ntl = c(9.14866638183594, 15.3856477737427, 
16.3302040100098, 12.454291343689, 10.4823837280273, 11.394606590271, 
8.1963529586792, 4.50725030899048, 3.95374751091003, 5.73203563690186, 
14.3955335617065, 17.0745468139648, 14.2944135665894, 10.333722114563, 
9.80743503570557, 12.5352020263672, 19.8813304901123, 29.2410221099854, 
32.8321876525879, 29.575023651123, 28.5894374847412, 36.4911346435547, 
49.4252128601074, 61.3118171691895, 58.6104736328125, 43.0437355041504, 
28.096061706543, 34.8003845214844, 32.1936340332031, 28.3407783508301, 
22.5178966522217, 17.0638084411621, 20.7549228668213, 18.3547439575195, 
10.2983675003052, 7.3524694442749, 7.17788362503052, 7.06999540328979, 
8.03957176208496, 12.6783542633057, 18.7537479400635, 26.1656856536865, 
36.539493560791, 41.0569839477539, 25.5366401672363, 15.7820110321045, 
9.87918758392334, 7.65169858932495, 6.96318626403809, 8.69833087921143, 
12.1393032073975, 15.151198387146, 14.5944147109985, 9.46016979217529, 
4.53868055343628, 12.8388118743896, 21.1265335083008, 19.3046970367432, 
10.5719947814941, 8.08844661712646), pop = c(31.2753772735596, 
55.8289375305176, 56.4003105163574, 33.795223236084, 31.0511913299561, 
30.5730743408203, 13.667106628418, 7.08161020278931, 6.89333772659302, 
13.9001550674438, 35.5272178649902, 42.4625587463379, 32.9688529968262, 
21.4302787780762, 12.6151924133301, 17.4939270019531, 38.1474113464355, 
60.8120536804199, 65.3665008544922, 53.8765907287598, 46.2705993652344, 
61.42333984375, 70.8307113647461, 53.3152236938477, 31.4083557128906, 
24.9810562133789, 38.3716621398926, 56.114860534668, 67.1656036376953, 
60.8404235839844, 33.7796592712402, 29.8311328887939, 44.3309173583984, 
31.9606342315674, 16.7053775787354, 10.1427822113037, 11.4020376205444, 
10.7794933319092, 18.2773151397705, 34.2912216186523, 50.6655197143555, 
52.1081962585449, 53.0502471923828, 59.4989013671875, 48.5897750854492, 
41.188159942627, 27.0699615478516, 11.5318984985352, 9.09538650512695, 
14.2379903793335, 24.8153190612793, 29.3468627929688, 30.5861835479736, 
15.3130531311035, 9.47307205200195, 37.2332077026367, 94.2268676757812, 
73.2485733032227, 26.8748569488525, 26.8519401550293), agbh = c(0.124395661056042, 
0.543155550956726, 0.930405616760254, 0.176615670323372, 0.122252210974693, 
1.86410081386566, 0.201039269566536, 0.00215102708898485, 0.00524011626839638, 
0.0221506990492344, 1.75632297992706, 0.954743504524231, 0.373224049806595, 
0.0127956680953503, 0.0007417316082865, 0.0123716788366437, 0.279229581356049, 
2.30779552459717, 2.58910322189331, 1.23243260383606, 0.819948613643646, 
1.74025285243988, 4.03071403503418, 2.78268098831177, 2.00978517532349, 
0.700970351696014, 0.196071043610573, 2.19463133811951, 4.83159875869751, 
2.20620393753052, 0.321354597806931, 0.00308413081802428, 1.737912774086, 
0.468539208173752, 0.0156131321564317, 0.00116395147051662, 0.0145542966201901, 
0.000892410753294826, 0.0419198162853718, 2.84171080589294, 3.22121715545654, 
2.73401832580566, 2.47091150283813, 2.10038590431213, 1.15651941299438, 
0.490403175354004, 0.0419915802776814, 0.101970501244068, 0.00181114906445146, 
0.0132269319146872, 0.212756171822548, 0.111757233738899, 1.2169703245163, 
0.129767879843712, 0, 0.582266986370087, 2.96843385696411, 1.16728830337524, 
0.0494964420795441, 0.0664984136819839), nir = c(0.261590600013733, 
0.250058531761169, 0.238313049077988, 0.246726274490356, 0.241509333252907, 
0.215491861104965, 0.25552836060524, 0.26755028963089, 0.283316373825073, 
0.2645283639431, 0.2347122579813, 0.250579416751862, 0.272739976644516, 
0.26601967215538, 0.260071456432343, 0.283827364444733, 0.270996034145355, 
0.229571804404259, 0.228905484080315, 0.240774929523468, 0.22843000292778, 
0.201068416237831, 0.174168020486832, 0.187955036759377, 0.235188364982605, 
0.226306527853012, 0.197943985462189, 0.192345812916756, 0.18694880604744, 
0.203041225671768, 0.24348683655262, 0.264572501182556, 0.234625786542892, 
0.252681404352188, 0.252072751522064, 0.241365790367126, 0.228045880794525, 
0.252986639738083, 0.261032313108444, 0.233464851975441, 0.235829710960388, 
0.235184907913208, 0.212146639823914, 0.204127430915833, 0.216947212815285, 
0.225598230957985, 0.231632620096207, 0.224976778030396, 0.219116434454918, 
0.255260914564133, 0.241265594959259, 0.237798929214478, 0.241482153534889, 
0.240964710712433, 0.252938002347946, 0.258243441581726, 0.211435839533806, 
0.217503502964973, 0.237074509263039, 0.237700119614601)), row.names = c(NA, 
60L), class = "data.frame")
r random-forest tidymodels

1 Answer

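Because your data is regularly spaced, some of the grid cell borders line up exactly with your observations. Under the hood, spatialsample uses sf::st_intersects() to determine which observations intersect which grid blocks. When an observation sits exactly on a cell boundary, it intersects both neighboring blocks and so ends up in the folds of both.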
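The boundary behavior can be reproduced in isolation. The following is a minimal sketch on made-up coordinates (the two-cell grid and the two points are illustrative assumptions, not the drought data): a point strictly inside a cell matches one block, while a point on the shared edge matches two.

```r
library(sf)

# Toy study area: x in [0, 2], y in [0, 1], split into two 1 x 1 cells.
area <- st_as_sfc(st_bbox(c(xmin = 0, ymin = 0, xmax = 2, ymax = 1)))
grid <- st_make_grid(area, n = c(2, 1))

pts <- st_sfc(
  st_point(c(0.5, 0.5)),  # strictly inside the first cell
  st_point(c(1.0, 0.5))   # exactly on the edge shared by both cells
)

# st_intersects() counts boundaries as part of a polygon,
# so the second point is reported for both cells.
hits <- st_intersects(pts, grid)
n_cells_per_point <- lengths(hits)
n_cells_per_point
```

Any point whose count here is greater than 1 would be placed in more than one block, which is how an assessment set can gain extra rows.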

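The fix is to offset the grid slightly so that the grid lines no longer line up exactly with your observations. You can control this through the offset argument of st_make_grid(), which we can pass via the ... of spatial_block_cv().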
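As a rough illustration of what the offset does at the sf level, here is a small sketch on made-up coordinates (the study area, the points, and the 0.25 shift are assumptions chosen for the toy case, not values recommended for real data). Moving the grid origin so that no grid line passes through a point leaves each point in exactly one cell.

```r
library(sf)

# Same toy layout: area x in [0, 2], y in [0, 1]; the second point
# would sit on a grid line if the grid started at the bounding box corner.
area <- st_as_sfc(st_bbox(c(xmin = 0, ymin = 0, xmax = 2, ymax = 1)))
pts <- st_sfc(st_point(c(0.5, 0.5)), st_point(c(1.0, 0.5)))

# Start the grid 0.25 units below and to the left of the bounding box.
shifted_grid <- st_make_grid(
  area,
  n = c(2, 1),
  offset = st_bbox(area)[c("xmin", "ymin")] - 0.25
)

# Each point now falls in the interior of a single cell.
cells_after_shift <- lengths(st_intersects(pts, shifted_grid))
cells_after_shift
```

The shift should be small relative to the cell size but large enough that no observation lands on a grid line; for regularly spaced data, a value that is not a multiple of the spacing is a reasonable starting point.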

For example, using your drought data:

drought <- structure(list(x = c(995494.2549, 995924.2549, 996354.2549, 996784.2549, 
                                997214.2549, 997644.2549, 998074.2549, 998504.2549, 998934.2549, 
                                999364.2549, 999794.2549, 1000224.2549, 1000654.2549, 1001084.2549, 
                                1001514.2549, 1001944.2549, 1002374.2549, 1002804.2549, 1003234.2549, 
                                1003664.2549, 1004094.2549, 1004524.2549, 1004954.2549, 1005384.2549, 
                                1005814.2549, 1006244.2549, 1006674.2549, 1007104.2549, 1007534.2549, 
                                1007964.2549, 1008394.2549, 1008824.2549, 1009254.2549, 1009684.2549, 
                                1010114.2549, 1010544.2549, 1010974.2549, 1011404.2549, 1011834.2549, 
                                1012264.2549, 1012694.2549, 1013124.2549, 1013554.2549, 1013984.2549, 
                                1014414.2549, 1014844.2549, 1015274.2549, 1015704.2549, 1016134.2549, 
                                1016564.2549, 1016994.2549, 1017424.2549, 1017854.2549, 1018284.2549, 
                                1018714.2549, 995494.2549, 995924.2549, 996354.2549, 996784.2549, 
                                997214.2549), y = c(1019851.5842, 1019851.5842, 1019851.5842, 
                                                    1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 
                                                    1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 
                                                    1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 
                                                    1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 
                                                    1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 
                                                    1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 
                                                    1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 
                                                    1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 
                                                    1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 
                                                    1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 1019851.5842, 
                                                    1019851.5842, 1019851.5842, 1019421.5842, 1019421.5842, 1019421.5842, 
                                                    1019421.5842, 1019421.5842), ntl = c(9.14866638183594, 15.3856477737427, 
                                                                                         16.3302040100098, 12.454291343689, 10.4823837280273, 11.394606590271, 
                                                                                         8.1963529586792, 4.50725030899048, 3.95374751091003, 5.73203563690186, 
                                                                                         14.3955335617065, 17.0745468139648, 14.2944135665894, 10.333722114563, 
                                                                                         9.80743503570557, 12.5352020263672, 19.8813304901123, 29.2410221099854, 
                                                                                         32.8321876525879, 29.575023651123, 28.5894374847412, 36.4911346435547, 
                                                                                         49.4252128601074, 61.3118171691895, 58.6104736328125, 43.0437355041504, 
                                                                                         28.096061706543, 34.8003845214844, 32.1936340332031, 28.3407783508301, 
                                                                                         22.5178966522217, 17.0638084411621, 20.7549228668213, 18.3547439575195, 
                                                                                         10.2983675003052, 7.3524694442749, 7.17788362503052, 7.06999540328979, 
                                                                                         8.03957176208496, 12.6783542633057, 18.7537479400635, 26.1656856536865, 
                                                                                         36.539493560791, 41.0569839477539, 25.5366401672363, 15.7820110321045, 
                                                                                         9.87918758392334, 7.65169858932495, 6.96318626403809, 8.69833087921143, 
                                                                                         12.1393032073975, 15.151198387146, 14.5944147109985, 9.46016979217529, 
                                                                                         4.53868055343628, 12.8388118743896, 21.1265335083008, 19.3046970367432, 
                                                                                         10.5719947814941, 8.08844661712646), pop = c(31.2753772735596, 
                                                                                                                                      55.8289375305176, 56.4003105163574, 33.795223236084, 31.0511913299561, 
                                                                                                                                      30.5730743408203, 13.667106628418, 7.08161020278931, 6.89333772659302, 
                                                                                                                                      13.9001550674438, 35.5272178649902, 42.4625587463379, 32.9688529968262, 
                                                                                                                                      21.4302787780762, 12.6151924133301, 17.4939270019531, 38.1474113464355, 
                                                                                                                                      60.8120536804199, 65.3665008544922, 53.8765907287598, 46.2705993652344, 
                                                                                                                                      61.42333984375, 70.8307113647461, 53.3152236938477, 31.4083557128906, 
                                                                                                                                      24.9810562133789, 38.3716621398926, 56.114860534668, 67.1656036376953, 
                                                                                                                                      60.8404235839844, 33.7796592712402, 29.8311328887939, 44.3309173583984, 
                                                                                                                                      31.9606342315674, 16.7053775787354, 10.1427822113037, 11.4020376205444, 
                                                                                                                                      10.7794933319092, 18.2773151397705, 34.2912216186523, 50.6655197143555, 
                                                                                                                                      52.1081962585449, 53.0502471923828, 59.4989013671875, 48.5897750854492, 
                                                                                                                                      41.188159942627, 27.0699615478516, 11.5318984985352, 9.09538650512695, 
                                                                                                                                      14.2379903793335, 24.8153190612793, 29.3468627929688, 30.5861835479736, 
                                                                                                                                      15.3130531311035, 9.47307205200195, 37.2332077026367, 94.2268676757812, 
                                                                                                                                      73.2485733032227, 26.8748569488525, 26.8519401550293), agbh = c(0.124395661056042, 
                                                                                                                                                                                                      0.543155550956726, 0.930405616760254, 0.176615670323372, 0.122252210974693, 
                                                                                                                                                                                                      1.86410081386566, 0.201039269566536, 0.00215102708898485, 0.00524011626839638, 
                                                                                                                                                                                                      0.0221506990492344, 1.75632297992706, 0.954743504524231, 0.373224049806595, 
                                                                                                                                                                                                      0.0127956680953503, 0.0007417316082865, 0.0123716788366437, 0.279229581356049, 
                                                                                                                                                                                                      2.30779552459717, 2.58910322189331, 1.23243260383606, 0.819948613643646, 
                                                                                                                                                                                                      1.74025285243988, 4.03071403503418, 2.78268098831177, 2.00978517532349, 
                                                                                                                                                                                                      0.700970351696014, 0.196071043610573, 2.19463133811951, 4.83159875869751, 
                                                                                                                                                                                                      2.20620393753052, 0.321354597806931, 0.00308413081802428, 1.737912774086, 
                                                                                                                                                                                                      0.468539208173752, 0.0156131321564317, 0.00116395147051662, 0.0145542966201901, 
                                                                                                                                                                                                      0.000892410753294826, 0.0419198162853718, 2.84171080589294, 3.22121715545654, 
                                                                                                                                                                                                      2.73401832580566, 2.47091150283813, 2.10038590431213, 1.15651941299438, 
                                                                                                                                                                                                      0.490403175354004, 0.0419915802776814, 0.101970501244068, 0.00181114906445146, 
                                                                                                                                                                                                      0.0132269319146872, 0.212756171822548, 0.111757233738899, 1.2169703245163, 
                                                                                                                                                                                                      0.129767879843712, 0, 0.582266986370087, 2.96843385696411, 1.16728830337524, 
                                                                                                                                                                                                      0.0494964420795441, 0.0664984136819839), nir = c(0.261590600013733, 
                                                                                                                                                                                                                                                       0.250058531761169, 0.238313049077988, 0.246726274490356, 0.241509333252907, 
                                                                                                                                                                                                                                                       0.215491861104965, 0.25552836060524, 0.26755028963089, 0.283316373825073, 
                                                                                                                                                                                                                                                       0.2645283639431, 0.2347122579813, 0.250579416751862, 0.272739976644516, 
                                                                                                                                                                                                                                                       0.26601967215538, 0.260071456432343, 0.283827364444733, 0.270996034145355, 
                                                                                                                                                                                                                                                       0.229571804404259, 0.228905484080315, 0.240774929523468, 0.22843000292778, 
                                                                                                                                                                                                                                                       0.201068416237831, 0.174168020486832, 0.187955036759377, 0.235188364982605, 
                                                                                                                                                                                                                                                       0.226306527853012, 0.197943985462189, 0.192345812916756, 0.18694880604744, 
                                                                                                                                                                                                                                                       0.203041225671768, 0.24348683655262, 0.264572501182556, 0.234625786542892, 
                                                                                                                                                                                                                                                       0.252681404352188, 0.252072751522064, 0.241365790367126, 0.228045880794525, 
                                                                                                                                                                                                                                                       0.252986639738083, 0.261032313108444, 0.233464851975441, 0.235829710960388, 
                                                                                                                                                                                                                                                       0.235184907913208, 0.212146639823914, 0.204127430915833, 0.216947212815285, 
                                                                                                                                                                                                                                                       0.225598230957985, 0.231632620096207, 0.224976778030396, 0.219116434454918, 
                                                                                                                                                                                                                                                       0.255260914564133, 0.241265594959259, 0.237798929214478, 0.241482153534889, 
                                                                                                                                                                                                                                                       0.240964710712433, 0.252938002347946, 0.258243441581726, 0.211435839533806, 
                                                                                                                                                                                                                                                       0.217503502964973, 0.237074509263039, 0.237700119614601)), row.names = c(NA, 
                                                                                                                                                                                                                                                                                                                                60L), class = "data.frame")

We can shift the grid by one meter and still get every observation into exactly one fold, without duplicating any of them:

library(tidymodels)
library(spatialsample)
library(sf)
#> Linking to GEOS 3.11.1, GDAL 3.6.4, PROJ 9.1.1; sf_use_s2() is TRUE
proj_ref_sys <- "EPSG:7760"
drought_sf <- st_as_sf(drought, coords = c("x", "y"),  crs = proj_ref_sys)

set.seed(123)
# default: 61 observations in assessment
folds <- spatial_block_cv(drought_sf, v = 3)
vapply(
  seq_len(nrow(folds)), 
  function(i) nrow(assessment(get_rsplit(folds, i))),
  numeric(1)
) |> 
  sum()
#> [1] 61

set.seed(123)
# With a tiny offset: 60 observations as we'd expect
folds <- spatial_block_cv(
  drought_sf, 
  v = 3,
  # This is the change: move our grid by 1 meter
  offset = st_bbox(drought_sf)[c("xmin", "ymin")] - 1
)
vapply(
  seq_len(nrow(folds)), 
  function(i) nrow(assessment(get_rsplit(folds, i))),
  numeric(1)
) |> 
  sum()
#> [1] 60

Created on 2023-10-31 with reprex v2.0.2

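This isn't great, because it is a silent problem. I'll open an issue about it in spatialsample in a minute, but my guess is that the first fix I implement will be to make this situation an error (suggesting that you pass an explicit offset), because I don't think it is trivial to automatically find, for arbitrary user-supplied data, an offset that is large enough to fix the problem without accidentally excluding observations.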
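Once every observation appears in exactly one assessment set, the residual table asked about in the question could be assembled along these lines. This is only a sketch: residuals_with_coords() is a hypothetical helper, not part of tidymodels or spatialsample, and the toy data frames with made-up values stand in for the original data and for the output of collect_predictions(), which is assumed to carry the .row, .pred, and outcome columns shown in the question.

```r
# Hypothetical helper: align out-of-fold predictions with the coordinates.
# `preds` is assumed to look like collect_predictions() output, i.e. it has
# .row (row index in the original data), .pred, and the outcome column.
residuals_with_coords <- function(original_df, preds, outcome = "ntl") {
  preds <- as.data.frame(preds)
  # Every original row must be predicted exactly once, otherwise the
  # residuals cannot be matched one-to-one with the coordinates.
  stopifnot(
    !anyDuplicated(preds$.row),
    nrow(preds) == nrow(original_df)
  )
  # Predictions come back grouped by fold, so restore the original row order.
  preds <- preds[order(preds$.row), ]
  cbind(
    original_df[preds$.row, c("x", "y")],
    rf_resids = preds[[outcome]] - preds$.pred  # observed minus predicted
  )
}

# Toy stand-ins (values are made up for illustration only).
toy_df <- data.frame(x = c(1, 2, 3), y = c(5, 5, 5), ntl = c(10, 20, 30))
toy_preds <- data.frame(
  .row  = c(3L, 1L, 2L),
  .pred = c(28, 11, 19),
  ntl   = c(30, 10, 20)
)

resids_df <- residuals_with_coords(toy_df, toy_preds)
resids_df
```

With the real objects, the call would presumably be residuals_with_coords(drought, collect_predictions(drought_res)), after refitting with the offset folds. The stopifnot() check fails loudly if a row is still duplicated, which avoids a silently misaligned cbind().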
