R 中每个类别的 SHAP 特征重要性图

问题描述 投票:0回答:1

我将 SHAP 方法应用于我的随机森林多类分类模型。

有没有办法获得:

  1. 不为 Y 变量的 8 个类别分别绘制 8 张图(图 1),而是像图 2 那样把它们合并成一张组合图?

  2. 只生成一张(汇总所有类别的)全局平均特征重要性图?

这是我的代码:

library(randomForest)
library(kernelshap)
library(shapviz)

# Multiclass random forest. The response is the factor `droughts_column`.
# NOTE(review): that column is not among the 31 columns shown in the dput()
# sample below -- confirm the response column name.
RF <- randomForest(
  droughts_column ~ .,
  data = train_data,
  ntree = 100,
  mtry = 17,
  importance = TRUE
)

# Select the predictors by name instead of by position: `train_data[, -1]`
# only drops the response if it happens to be the first column (in the dput()
# sample the first column is `min_column`). as.data.frame() makes the column
# subset behave the same whether train_data is a data.frame or a data.table.
pred_cols <- setdiff(names(train_data), "droughts_column")
X <- as.data.frame(train_data)[pred_cols]

# type = "prob" is forwarded to predict() so it returns class probabilities;
# without it, kernelshap() stops with "Predictions must be numeric".
# NOTE(review): bg_X (the background data) is not defined in this snippet and
# must be created beforehand, e.g. a few hundred rows of X.
s <- kernelshap(RF, X = X, bg_X = bg_X, type = "prob")

# One set of SHAP values per class (a multi-output shapviz object).
sv <- shapviz(s)

sv_importance(sv, kind = "bar", max_display = 10)

这是我的数据:

# First 20 rows and 31 columns of train_data, as printed by
# dput(train_data[c(1:20), 1:31]), assigned so the example can be pasted into
# R. The data.table-only `.internal.selfref = <pointer: ...>` attribute that
# dput() printed has been removed: a raw memory address is not valid R syntax,
# so the original output could not be parsed.
# NOTE(review): the model formula uses `droughts_column`, which is not among
# these 31 columns; `min_column` (a factor with 8 levels) may be the intended
# response -- confirm.
train_data <- structure(list(min_column = structure(c(4L, 5L, 4L, 7L, 7L, 5L, 
8L, 5L, 7L, 5L, 6L, 4L, 4L, 7L, 7L, 4L, 1L, 5L, 8L, 8L), levels = c("PREC3", 
"PPSTV", "PPSV", "PREC6", "SM3", "SV", "TWSA3", "VPD3"), class = "factor"), 
    aridity_index = c(1.05540223890365, 1.42291223741047, 0.289312012765131, 
    2.08955832504966, 1.29389651632679, 0.327431291000845, 3.31130880723198, 
    0.435962678062337, 0.720398000663778, 0.317880591943831, 
    1.25169086405035, 0.686319667529413, 1.40184251502531, 1.35417639189355, 
    0.640902793146427, 1.38923272127388, 0.48987851172821, 0.559316314968663, 
    0.317328910956498, 0.34518109028685), mean_t2m = c(25.8957153320313, 
    11.9346964518229, 29.8120997111003, 27.2445922851563, 26.2609522501628, 
    27.6344212849935, 25.1281499226888, 19.8798919677735, 17.526274617513, 
    31.3262395222982, 20.2956227620443, 23.411154683431, 20.2042404174805, 
    23.607448832194, 17.7585464477539, 25.0856745402018, 23.8563552856446, 
    22.4191426595052, 22.3183787027995, 23.6592041015625), sd_t2m = c(0.748837584496715, 
    1.69355443757821, 1.94353385216372, 0.393817169978326, 0.437997449158287, 
    1.20966419549391, 0.30109030223568, 1.46221521450813, 1.20757543363649, 
    0.384620838283781, 0.951142946899067, 0.660012290331394, 
    1.05941750720462, 0.623493914695329, 1.66694863819749, 0.828680288883452, 
    0.9433044402992, 1.24476846679161, 1.37801589084397, 0.746930915488626
    ), mean_vpd = c(10.6565330028534, 3.94749279816945, 29.6649556954702, 
    8.2109356323878, 7.93534930547078, 17.9001136620839, 4.93418568372726, 
    14.3897955814997, 6.24165391921997, 31.4230984052022, 8.02857120831808, 
    8.20808569590251, 7.72910133997599, 8.53423659006754, 8.67752106984456, 
    5.9931894938151, 11.9273592233658, 16.4530976613363, 15.9718067646027, 
    14.2948899269104), sd_vpd = c(1.73944337077899, 0.755394420255983, 
    8.05402389935442, 1.10117851181646, 0.948786864039974, 3.09593279514218, 
    0.610038282171325, 3.45080704918886, 1.19894214225847, 3.17718368822484, 
    1.50526167566968, 1.25732973567913, 1.22045801441578, 1.21170835192973, 
    2.30159243776068, 1.483610883746, 2.40324777669803, 2.78005896417215, 
    4.37614864431736, 1.7897095039182), mean_prec = c(92.4068752924601, 
    84.150003751119, 47.7314598560333, 118.624369303385, 259.424163818359, 
    29.1541662439704, 222.893957773844, 33.1908336480459, 107.695629437764, 
    7.0000001937151, 133.085209528605, 118.405418713888, 113.004998207092, 
    109.223124663035, 45.3147913614909, 241.841251373291, 116.099378267924, 
    47.0570815404256, 32.3000017683953, 18.1831246614456), sd_prec = c(25.2230387606066, 
    23.3214254568287, 54.2045520974117, 54.0742759998264, 84.2605910448614, 
    44.5100269137947, 61.7197090730034, 27.3946301979818, 39.1184086642315, 
    5.5819938843692, 118.01624202547, 66.4338077887344, 48.0153506259032, 
    62.4349963152462, 35.9822947596737, 177.032327827943, 40.7720457797305, 
    41.9353214009905, 48.2970552400401, 11.0728449402471), mean_sm = c(0.280448893706004, 
    0.202963148554166, 0.138948903108637, 0.352660410106182, 
    0.278219901025295, 0.173655264079571, 0.311909670631091, 
    0.220840279012918, 0.26629921918114, 0.145378011589249, 0.289096682022015, 
    0.238765094429255, 0.293268837034702, 0.352709278464317, 
    0.245920985937119, 0.299343595902125, 0.243301281084617, 
    0.14155216080447, 0.188165209566553, 0.230367512752612), 
    sd_sm = c(0.0120999081770558, 0.0106934897216381, 0.0449606306723597, 
    0.00935333257513591, 0.0170730820627545, 0.0287503551492261, 
    0.00551166949866713, 0.0263233458019688, 0.0224377766941138, 
    0.0231998024046729, 0.0261685416345269, 0.0252500867384453, 
    0.0133047644235911, 0.0121125656223249, 0.020165065621216, 
    0.0105683334975659, 0.0268471729517476, 0.0209773792143915, 
    0.0484241248619342, 0.0123767393717239), mean_snr = c(12.0064274470011, 
    6.76889745394389, 9.40070144335429, 12.9343137741089, 12.6970313390096, 
    9.85343869527181, 11.4195901552836, 12.5402154922485, 9.32121702035268, 
    7.51378111044566, 9.33259244759878, 13.1668708324432, 10.5996281305949, 
    8.06930311520894, 8.90324242909749, 11.8287103176117, 12.5012144247691, 
    16.2886624336243, 8.91103057066599, 7.04941769440969), sd_snr = c(1.14155839883206, 
    0.892173051553268, 0.654157370607644, 0.827404584595355, 
    0.790501626018555, 0.403225009907126, 0.515157191337593, 
    0.378115825721952, 1.05029652548267, 0.497559161125383, 0.667521104274022, 
    0.719427400607346, 0.58177946280854, 0.736912015197759, 0.751189102845447, 
    1.54676429427932, 0.579172898464673, 0.581649429155848, 0.757825017522055, 
    0.66237735678821), grass.man = c(16.052395277553, 0, 0, 0.66145833317811, 
    1.43209873843522, 15.3672170920507, 0.00144675930237593, 
    70.6226110458368, 0.486689811494829, 12.669270892938, 6.96614590287209, 
    35.5963554183634, 4.59172460436825, 77.6109197404646, 97.3732637829251, 
    29.6277959677904, 31.8275464773165, 89.6716828346252, 1.30208332091568, 
    0.0188078705105002), grass.nat = c(69.4744475682573, 49.188658979204, 
    16.0431131124497, 50.8833920293391, 19.497588634491, 60.9923440615334, 
    0.12577160423906, 22.4525263309482, 53.1556712786355, 80.8637145360311, 
    91.4583339691162, 38.05237253507, 13.8398921489716, 8.84162787596396, 
    2.06963733335332, 17.9012338386641, 37.8492470979698, 5.67679386999876, 
    86.2560768127439, 82.8998835881549), slope = c(2.16054304838179, 
    1.97955802321434, 0.459141999483106, 4.04891811609252, 1.37888130545614, 
    2.56013614609837, 2.42401843070985, 5.54462247371674, 2.90989639997484, 
    0.852418287396432, 1.59599373936653, 2.52898863553996, 6.2659276509285, 
    1.76413418650627, 0.292095533311367, 1.12207695484161, 1.18567454397679, 
    0.569621352553369, 1.84543817639354, 3.72511589765552), elevation = c(359.97704925537, 
    398.722771911621, 426.586759948729, 113.829170227046, 221.654676361083, 
    118.164677486419, 193.076432800294, 1278.94374267578, 366.954654235841, 
    310.976059875488, 115.406995010376, 824.692267456063, 452.296558532715, 
    24.027112817764, 120.484393692017, 328.746539916992, 1225.43281127929, 
    187.192862854004, 1294.85421752929, 1099.986953125), irrigated_pct = c(1.25787228625548, 
    0, 0, 0.00231511534741727, 0.000534317362681038, 10.66306531353, 
    0, 5.93849432875218, 0.00258188543729491, 0, 0.132554369636341, 
    0, 0.0692080791329039, 34.1391748461725, 0.201742927556814, 
    0.0077405653046329, 0.00480382097959415, 0.00376549045024443, 
    0.0321618323922187, 0.0455722879171372), T_SILT = c(6.45358939186102, 
    23.0251486053956, 5.855814662399, 33.130658436214, 5.71982167352536, 
    22.0589468068892, 13.9091220850479, 36.8535665294925, 33.0751028806584, 
    22.6529492455417, 44.4293552812071, 10.7534293552813, 39, 
    36.1208276177412, 40.1746684956561, 14.0314357567444, 29.253886602652, 
    35.6460143270843, 14.7945816186557, 28.5586419753087), T_SAND = c(77.6197226032616, 
    65.1125971650663, 83.4992379210485, 31.9764136564551, 71.943072702332, 
    27.0651958542905, 55.8758573388203, 39.8782578875171, 43.9490169181527, 
    61.8762002743486, 23.377914951989, 19.1499771376313, 41, 
    36.0430574607529, 30.3878600823045, 60.5447721383936, 47.3446502057613, 
    26.8314281359549, 71.0716735253772, 41.3868312757202), T_CLAY = c(15.9266880048773, 
    10.3419067215364, 6.73929279073309, 34.8929279073309, 22.3371056241426, 
    50.8758573388203, 30.2150205761318, 22.10219478738, 22.9758802011889, 
    15.4708504801097, 32.1927297668039, 70.0965935070874, 20, 
    19.7199740893156, 29.4374714220393, 25.423792104862, 23.4014631915867, 
    37.5225575369608, 14.133744855967, 30.0545267489712), fraction_tree = c(8.75704383786053, 
    25.7484571668838, 0, 41.6792055434643, 71.8004441261293, 
    6.29989361806966, 98.8653546733872, 0.887232535285893, 44.7146998239333, 
    0.101562502017867, 1.24826386972564, 17.082176297903, 79.1825804582072, 
    2.98678628670882, 0.126157407784675, 49.9505215655604, 9.95341438055052, 
    0.0438850315468813, 2.35069438815109, 0.080439814832063), 
    fraction_shrubs = c(0.150779876418225, 4.97656261755356, 
    0.00617283955216385, 1.91213345651827, 7.23668978611617, 
    8.4761317869028, 0.00125385803403324, 0.0193571608397205, 
    0.0925925932824588, 0.0286458334497496, 0.0703125018626452, 
    9.25520811975072, 0, 0, 0.0070408952111995, 2.20023143870961, 
    20.3544558684036, 1.56240353929916, 9.99826425313981, 16.9250582066212
    ), Artificial_Surfaces = c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0), CropLand = c(1, 0, 0, 0, 0, 
    0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0), Grassland = c(0, 
    0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1), 
    Tree_Covered_Area = c(0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 
    1, 0, 0, 1, 1, 0, 0, 0), Shrubs_Covered_Area = c(0, 0, 0, 
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0), Herbaceous_aquatic_regularly_flooded = c(0, 
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), 
    Mangroves = c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
    0, 0, 0, 0, 0), Sparse_vegetation = c(0, 0, 1, 0, 0, 0, 0, 
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), BareSoil = c(0, 0, 
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)), class = c("data.table", 
"data.frame"), row.names = c(NA, -20L))
r plot random-forest shap
1个回答
0
投票

正如我们通过电子邮件讨论过的,在概率尺度上汇总平均绝对 SHAP 值可能并不是最好的做法。

不过,你可以这样做:

library(ranger)
library(kernelshap)
library(shapviz)
library(tidyverse)

# Probability forest: predictions are one probability column per species,
# so kernelshap() yields one set of SHAP values per class.
set.seed(1)  # ranger is randomized; fix the seed so results are reproducible
RF <- ranger(
  Species ~ .,
  data = iris,
  num.trees = 100,
  probability = TRUE
)

# Explain the four measurements (column 5 is the response, Species).
s <- kernelshap(RF, X = iris[, -5], bg_X = iris)
sv <- shapviz(s)

# Mean absolute SHAP value per feature (rows) and class (columns);
# kind = "no" returns the numbers instead of drawing a plot.
(imp <- as.data.frame(sv_importance(sv, kind = "no")))
# Example output (exact values depend on the random seed):
#                  setosa  versicolor   virginica
# Petal.Length 0.23553917 0.208185406 0.201901984
# Petal.Width  0.18105394 0.193248091 0.191982496
# Sepal.Length 0.01866434 0.019473629 0.021358250
# Sepal.Width  0.01088245 0.009355846 0.006084506

# Long format: one row per (Variable, Class) pair, importance in `value`.
imp_reshaped <- imp |>
  rownames_to_column(var = "Variable") |>
  pivot_longer(-Variable, names_to = "Class")

# Stacked horizontal bars: each segment is one class's mean |SHAP|, so the
# total bar length is the importance summed over all classes (a single global
# ranking). Variables are ordered by that total.
ggplot(
  imp_reshaped,
  aes(x = value, y = reorder(Variable, value, FUN = sum), fill = Class)
) +
  geom_col() +
  scale_fill_viridis_d(begin = 0.2, end = 0.8, option = "B") +
  labs(x = "Average absolute SHAP values", y = NULL)

[![enter image description here][1]][1]


  [1]: https://i.stack.imgur.com/qfHAW.png
© www.soinside.com 2019 - 2024. All rights reserved.