使用 rpart 和 rattle 为 R 中的四分类决策树创建更清晰的可视化

问题描述 投票:0回答:1

在下面的 R 代码中,我使用 rpart 和 caret 包生成决策树,并使用 rattle 包进行可视化。生成的图是一棵四分类决策树,但由于结构复杂,很难解读。我想为每个类别分别绘制一棵树来提高清晰度,使展示更聚焦、更易于理解,想请教具体的实现方法。

library(rpart)
library(caret)

# Tune the rpart complexity parameter (cp) with 10 x 10 repeated
# cross-validation via caret, then refit a single rpart tree on the whole
# training set using the selected cp.
#
# Assumes `training_set` is a data.frame whose factor column `Target` holds the
# class labels and whose remaining columns are the predictors.

# Use one split criterion for BOTH the tuning run and the final fit. Tuning cp
# under one impurity measure and refitting under another means the selected cp
# was never cross-validated for the tree that is actually plotted.
split_criterion <- "information"

# Repeated CV draws random folds; fix the seed so the selected cp (and hence
# the plotted tree) is reproducible from run to run.
set.seed(123)

fitControl <- trainControl(method = "repeatedcv",
                           number = 10,
                           repeats = 10)

classifier <- train(
  # drop = FALSE keeps a data.frame even if only one predictor remains.
  x = training_set[, names(training_set) != "Target", drop = FALSE],
  y = training_set$Target,
  method = "rpart",
  parms = list(split = split_criterion),
  trControl = fitControl,
  tuneLength = 20
)
classifier

# `bestTune` is a one-row data.frame; pass the numeric cp value itself rather
# than the whole data.frame to rpart.control().
complexity_parameter <- classifier$bestTune$cp

classifier <- rpart(formula = Target ~ .,
                    data = training_set,
                    parms = list(split = split_criterion),
                    control = rpart.control(cp = complexity_parameter))


library(RColorBrewer)
library(rattle)
# Draw the fitted rpart tree `classifier` with rattle's fancyRpartPlot().
# The extra arguments (type, branch, tweak, clip.right.labs) are not defined
# here; presumably they are forwarded to rpart.plot's prp() — confirm in the
# rattle docs before changing them.
fancyRpartPlot(
  classifier,
  caption = NULL,
  type = 3,
  branch = 0.3,
  tweak = 1.4,
  clip.right.labs = FALSE
)

我有一棵四分类决策树(见附图),其可视化非常复杂,难以解读。我考虑绘制四棵单独的树,每棵只针对一个类别(例如,第一棵树只显示预测为第一类的叶节点,第二棵树对应第二类,依此类推)。如能提供实现方法或代码片段,以提高决策树图的可解释性,将不胜感激。

编辑:以下是训练数据集的示例:

training_set<- structure(list(AGE67CIYes = c(-0.176152387930331, -0.987016328202176, 
0.05552302357591, -0.58468873762319, 0.162742606800352, 0.120896778307869, 
-0.987016328202176, -0.64359160055763, -0.987016328202176, -0.628598979432629, 
-0.987016328202176, -0.307993987241449, 0.889554504998379, -0.987016328202176, 
-0.84077646366108, -0.122806076070342, 0.347797256654688, -0.585917218815798, 
3.27330664446935, -0.227710210722183, -0.987016328202176, 0.0907486763211531, 
-0.468831088265139, -0.0852317172820009, 2.14649177996699, -0.21957742854859, 
-0.478526166947832, -0.987016328202176, 0.856614535142944, -0.987016328202176, 
0.233732369261435, 0.773841021986012, 1.76557884040399, 1.70409677446699, 
0.177204736845891, -0.987016328202176, -0.0303666544722618, 0.267016686824448, 
-0.987016328202176, 0.133602302476064, -0.780150252101327, 0.569019137931335, 
0.54169474123801, 1.04432323350976, -0.00304660292847676, 0.595633772087449, 
0.0119187160870928, -0.987016328202176, 0.445153602815489, -0.0754726273166524, 
-0.0553854181026097, -0.987016328202176, 0.447212111288207, -0.412974267062895, 
0.565701855297101, 0.0332551927612325, 0.61493438306659, -0.987016328202176, 
-0.245916422283762, 0.0936414642259315, 0.217200249252726, -0.774974426145616, 
2.01102787070915, 0.644784396320045, 1.31792583076954, 0.0693891516694634, 
0.152242180258608, 1.09469958100705, 1.13548440454805, -0.158246875053666, 
0.755736021038804, 0.672766538708062, -0.735850059896174, -0.987016328202176, 
0.145397105625745, 1.03910128090896, -0.987016328202176, -0.987016328202176, 
-0.404673257734619, 0.215913693080231, 0.388480617599278, 0.411918265238067, 
-0.987016328202176, 0.113253215693915, 0.334391574463053, -0.558854795203353, 
1.27293994403935, 0.429900076191951, -0.535607536710634, 1.15212162607829, 
1.49001293895707, -0.987016328202176, 0.886209458949893, -0.11303124287923, 
3.37348021367463, 0.735737223588497, -0.0476125737990377, -0.14765557213803, 
-0.481299453037445, -0.0487309116018985), AGE18_34M = c(-0.311288013988282, 
-0.920252005474412, 0.839075436208503, 0.475468979668509, 0.361813050743929, 
-0.0880491662273217, -1.30060058115752, 0.2381797045763, 0.46997486395285, 
0.0951842010198588, -0.249345366669994, 0.0276151129841655, -0.679404318051254, 
-0.460158578131396, -0.0580351642511107, -0.4914107121112, 0.0350108362887032, 
-0.726823880120434, -1.21665514019721, -0.665041398673074, 2.03127028459426, 
0.173414240816047, -1.06412941985671, -0.379467449594911, 0.971659698265082, 
3.58247384148793, 0.0419848907173112, -0.122469737335293, 1.77830824717609, 
-0.514049917080521, -0.18293712720485, 2.46932240611204, 0.103057098341037, 
0.416637231962104, 0.0347935126703198, -0.0683916002418135, -0.140487578187558, 
0.568918482749401, 0.80596534428402, 0.709398516775547, -0.894621921017203, 
0.0649799930004623, -0.548245672308119, 1.44034787873404, 0.358249364890296, 
2.26940685862186, -0.268422759950531, 1.5743831289976, -0.532610474943992, 
-0.348407548366379, 0.853760927062079, -0.291156653609838, -0.243149894243735, 
-0.844996329920804, 0.648164055112016, 0.273498687490394, 1.25994011310977, 
-0.792349093161483, 1.1447573459646, -0.013918010591892, -0.526596681137413, 
-0.658346817606973, 0.516140362965466, -0.78849760892925, -0.554157331051845, 
0.243647599042001, 2.04373725057226, -0.97656446859546, -0.912356887137583, 
-0.824415433098142, -0.335325459216275, -0.587458300768377, -0.258938778587358, 
-0.555568876651955, 4.63824989974633, -0.412721089416945, 1.00499324152535, 
-0.306301918538513, -0.228161687409781, -0.129884546841866, -0.0114026911643898, 
1.87550512616216, -0.944634183197748, -0.262575555185161, 0.861798664854366, 
-0.0509189201758156, -0.612609020943053, 0.110701490753568, -0.726367609420112, 
-0.375128625430529, -0.99933686605059, -0.19110078784832, -0.371574440790039, 
-0.938132768287129, 0.243368554114898, 0.929595838846939, -0.88523573776942, 
-0.810391389504971, -0.331438750074501, 0.873586339714407), AGE67 = c(0.501041434601308, 
-1.09665902567313, -0.499366515285846, -1.89627995143666, -0.298655565821342, 
-0.657227918269277, 0.738638322699508, 0.844563870334618, -1.56421059601705, 
-0.0241556789265915, -0.641594843326121, -0.478951705500202, 
2.61976255704089, 1.96434823104864, -1.22122015725059, 1.55163432276405, 
1.27827046071975, -0.343864172693654, 3.31198498838896, 1.66372654109949, 
0.37030630478651, -0.419408311079017, -0.555972438485822, -0.774936554113457, 
0.107872632262183, -1.34044707928847, -0.839545557277716, 1.92619155286728, 
0.57973553068234, 0.221251081710207, -0.163884033780637, -0.687948777957268, 
-0.0113182539359213, -0.37327476423973, -1.27171923730041, -1.38681856755255, 
-1.40429288028817, 0.168119278247155, -0.659256439010892, -1.07437656281392, 
-0.696986471528318, -0.864024735676032, -0.534449400062125, -0.180293103433951, 
2.30163523766365, 3.15498511491188, 0.307283716170217, -0.157204374134115, 
-0.518770979831071, -0.401045752552977, 0.997254420383102, 0.232181798069135, 
-0.405998845074965, -0.568631832646478, 0.254930436291304, 0.404278675370718, 
0.392765574171448, 2.47506069286078, -1.8557905876778, -0.0631740179401674, 
0.0298313302450834, -0.647883494529081, -0.207182743392638, -1.39352452511716, 
-1.34117850614179, 0.944792602443884, -1.24991896454826, -1.86690843412548, 
-0.905681536580039, 0.665146208715289, -0.322988621135965, 0.014606612397604, 
-0.911460946841895, -0.218306616580415, 0.26710983425054, 0.514703818235693, 
2.14113331398361, -0.616093681265326, 0.0291501450665863, -0.44680387231444, 
-0.25559820550726, -0.486781511123927, 1.46702523452746, 2.14512158886705, 
0.586724522901828, -0.354832348471152, -0.0891185777716379, 0.684009304312411, 
-0.946069291878499, -0.730270080428137, -0.101775097076054, 1.62556716185319, 
0.37030630478651, -0.819482083082821, 1.64647740419461, -0.556561969797837, 
0.0432415212999982, 0.279105885015719, -0.719800485475731, 0.352766684885891
), AGE67IT = c(0.546552412166264, -1.07526365158326, -0.468955640670514, 
-1.8974512238005, -0.310213066218594, -0.648052048503732, 0.787735574714999, 
0.854938313505668, -1.54987241979056, -0.0707358644941545, -0.726319975489725, 
-0.475290635689973, 2.4034814230678, 2.03194628146796, -1.20170490207676, 
1.61300289199568, 1.25342092240672, -0.311106154213347, 3.34434571314056, 
1.72678702415178, 0.413843973443276, -0.423068606116984, -0.526415970583246, 
-0.795742413966101, 0.0958131387143464, -1.34275483486837, -0.847437129878995, 
1.85139271626387, 0.555660476058683, 0.188567559007975, -0.160260210148408, 
-0.76145871420041, 0.02645919447723, -0.384844658381296, -1.27679963538867, 
-1.36980284573822, -1.38754091479577, 0.208605146969538, -0.669199963819832, 
-1.07508354694689, -0.669558461737779, -0.845642806091824, -0.50456806459325, 
-0.145066152169093, 2.29730541056113, 3.24055457341925, 0.271679866495853, 
-0.173297360212068, -0.50168800913321, -0.399729324986892, 0.977333088835057, 
0.27363462296638, -0.404248928734908, -0.574213826955535, 0.272419254286123, 
0.408398950123785, 0.436642241813503, 2.55036740928169, -1.84585350391826, 
-0.0355778324729171, 0.0525201570882496, -0.619714327085901, 
-0.182564609086414, -1.38913227212412, -1.32347389573356, 0.952476863305029, 
-1.23894360432029, -1.86146509438973, -0.923463257162492, 0.648263392221682, 
-0.358121272963515, -0.00290343411139714, -0.887270203066352, 
-0.205933003746037, 0.309089861018009, 0.520773126763273, 2.11253418114413, 
-0.58744468011722, -0.0301375036843241, -0.425015388384307, -0.228529543651722, 
-0.456180679143214, 1.29606011127543, 2.1676026243595, 0.633528796321265, 
-0.342348134083138, -0.0525154431967221, 0.686070667402912, -0.922400924572503, 
-0.716563246171896, -0.0653629993121732, 1.68805166931954, 0.413843973443276, 
-0.82126700480962, 1.66661355031845, -0.527014400248682, 0.0573323154964607, 
0.255567177377771, -0.732301118039339, 0.396039611477956), DWE1tot = c(-1.24627225843832, 
-0.972515585869268, -0.591117336947849, -0.950950893751152, -0.00341326678369597, 
-0.317703674167421, -1.01860592322038, 1.70500797025228, -2.80147690897104, 
0.0841927598217607, -2.31467874579201, -0.634564452794059, 1.63477150881602, 
2.39176448207346, -0.603097208459657, 0.554218462480144, 0.596025520656533, 
-0.363363481034292, 0.326011018689958, 1.45302282344243, 0.259839566760795, 
-0.323873062136694, -0.645196692893364, 0.126535915869014, 0.273354029798474, 
-1.6765246798592, -0.858119384992372, 0.559696169973685, 0.965787728534269, 
0.168732755000225, -0.105465993392139, -0.476904337564692, 0.158116000800885, 
0.368839813382008, -1.53101094871713, -0.768492476757933, -0.580381980928888, 
0.383757877076662, -2.51125311699879, -0.285564675616673, -0.307208273135031, 
-0.1958624767505, -0.52862437997276, -1.09793370211631, 2.37364271507784, 
1.94326659579375, -0.29069268473587, -1.3926121391431, -0.454695942608588, 
-0.574896772720613, 1.861531778717, 0.988010759757147, 0.0946589558731531, 
-0.4279250695409, -0.641299416040736, 0.211479245514941, -0.594712699321882, 
2.85291296625091, -0.317088857444325, 0.30614248023326, -0.375126422991232, 
-0.631092891479471, -0.286277701436183, -1.74098962562517, -1.24105303470107, 
0.709126059783464, -0.915573987456023, -2.33328457590235, -1.24384520039966, 
0.371696888960207, 0.0337244463841061, 0.718542081472043, -1.15989953982117, 
-0.249183962565405, -0.457050613250607, 0.543066989542432, 1.55057780183794, 
-0.319116485641744, -0.486838166179266, -0.71584875469315, -0.292203836118496, 
-0.0184323959900459, 2.93015633378211, 2.52801762478007, 0.435292040053886, 
-0.373077593601094, 0.295243869739579, 0.425845049612918, -0.475620953880973, 
-1.10368797436764, 0.346155709987028, 0.434151223026415, 0.632305650134888, 
-0.499317186271939, 1.66284666421724, -0.464006541934083, -0.395742174424708, 
0.0804687862477327, -0.714833790842392, 0.529197578800853), AGE18_34NoIT = c(0.19940732465257, 
-0.46709113395183, 2.34119335826173, -0.287924912209743, -0.767850884074149, 
-0.107050721676688, -0.767850884074149, 0.358720441384029, 0.30950776246791, 
0.94233552589833, 0.0931152138034797, 0.47658498239766, -0.767850884074149, 
-0.767850884074149, 0.671325955294912, -0.767850884074149, 0.369480488190256, 
-0.289390335024597, 0.0791536014054641, -0.767850884074149, 2.79341797532877, 
0.225596235651883, -0.767850884074149, -0.767850884074149, -0.472755159377862, 
2.58882706714496, 1.12766318625321, 0.312810011192943, 2.00670978225976, 
-0.486023852035069, 0.136757539237998, 1.68270945706995, -0.767850884074149, 
0.970984615352628, 1.24052217326025, -0.510972474543447, 0.729924457795809, 
0.54106416210258, 0.677682601292349, 1.00412854244753, 0.21920946120918, 
0.458536361427269, -0.767850884074149, -0.767850884074149, -0.767850884074149, 
-0.767850884074149, -0.767850884074149, -0.374142768964279, -0.45000944428971, 
0.0476664458114263, 0.899126028837855, -0.767850884074149, -0.355413039955347, 
-0.0830912490542143, -0.397411570093704, -0.159323389651903, 
0.187611492838829, -0.767850884074149, 0.558206914067417, -0.624618864134726, 
-0.767850884074149, -0.135502902582105, -0.262505861107523, -0.214425105898759, 
-0.664095989312535, -0.0650703296976993, 2.52667379006469, -0.570071981006544, 
-0.471397148679766, 0.344343491575343, -0.248130055504569, -0.485006425926625, 
0.580392393481832, 0.844944764062817, 0.24526870523875, -0.365033840851197, 
-0.767850884074149, 0.202401932109928, 0.298950997178277, -0.552609359604742, 
-0.286313505862145, -0.267225229717189, 0.992776417203698, 0.252968854126036, 
-0.337958346862274, -0.103885503507503, -0.767850884074149, -0.06359996356301, 
-0.498614063088361, -0.264221690329575, -0.767850884074149, 0.0117303587802212, 
-0.767850884074149, -0.455084756701438, -0.767850884074149, -0.510972474543447, 
-0.767850884074149, 0.233400424639771, 0.137033599931982, -0.20822292045369
), AGE67CINo = c(0.520215683696846, -1.03163833351555, -0.509119810566456, 
-1.87130724349216, -0.314585715658465, -0.673792033984295, 0.823598776742872, 
0.903906066078584, -1.50426966218324, 0.02457959376149, -0.57163006325451, 
-0.460147962605314, 2.57888667222345, 2.06262552649387, -1.16895180207709, 
1.5780646265547, 1.26504842710326, -0.301929339443148, 3.09282231729659, 
1.69955169568212, 0.45126498073184, -0.431038644185776, -0.52546864588487, 
-0.776712327050538, -0.0582690276197286, -1.33789498354707, -0.811366964612021, 
2.0240542914681, 0.519263167059362, 0.300590333065505, -0.18388341953198, 
-0.755742030018846, -0.149063665921569, -0.510160314132326, -1.29934861634348, 
-1.32495034122243, -1.41718294872725, 0.1491325448279, -0.589483545018553, 
-1.09646319324319, -0.643748191001391, -0.91776562934997, -0.582479723072687, 
-0.263654292536764, 2.32687904677992, 3.14283459045014, 0.309693195176315, 
-0.0819768879271749, -0.55910582347913, -0.39952005005421, 1.01240633996832, 
0.311639807227703, -0.445268920756996, -0.542619478394475, 0.213605110937565, 
0.406078840439239, 0.349100219631252, 2.57888667222345, -1.8567844317766, 
-0.0711594953896716, 0.0132252207188618, -0.594515172713188, 
-0.366188280170573, -1.45892393133304, -1.45847888111253, 0.949648680774444, 
-1.2753657237661, -1.97252072522736, -1.00402956173803, 0.684707699839703, 
-0.385405489574873, -0.0376751575317439, -0.864005933475389, 
-0.143742981474305, 0.258678534519238, 0.439300527132514, 2.24133130770795, 
-0.545851841402966, 0.0610101153751223, -0.46848814631922, -0.288656266519766, 
-0.524178190926454, 1.55989931716918, 2.15959966135947, 0.567034062253331, 
-0.315126128071832, -0.189309291975311, 0.657931155146952, -0.914598686765673, 
-0.828009160908805, -0.219023614915597, 1.72016373828187, 0.305251727394181, 
-0.819574921820783, 1.40141138963966, -0.619957661890062, 0.0474226036519372, 
0.293647638969982, -0.690104740437684, 0.360397804324386), AGE67CAR2 = c(-0.65814805058186, 
-0.65814805058186, -0.65814805058186, 0.00768964415182037, -0.65814805058186, 
1.19533826054216, -0.65814805058186, -0.231881369541573, -0.65814805058186, 
0.231603684740978, 0.237714616989074, -0.397640616985576, 4.000337820787, 
2.0821377561057, -0.476631762802389, 0.414529617167548, 0.584452971176845, 
-0.160294598679845, 0.223187114271708, 0.284320189290452, -0.65814805058186, 
-0.439244037830135, -0.65814805058186, 0.0880625008959889, 0.365373641181542, 
-0.0231061596455018, -0.289977393005934, 0.466314056300278, -0.464325455824444, 
-0.0716480308052054, -0.428569051007219, 1.23605308567625, 0.101092952900841, 
1.01198772916626, -0.280209804949945, -0.65814805058186, -0.361294148901551, 
-0.65814805058186, -0.0564985837999044, -0.124416160746424, -0.65814805058186, 
-0.41672466856486, -0.116013368959549, -0.65814805058186, 0.563177833297616, 
3.27069545539188, -0.65814805058186, -0.65814805058186, -0.451445457404997, 
-0.415698971774559, -0.65814805058186, 1.37993951814202, -0.308466424440895, 
-0.420643308414475, -0.272694136875217, -0.0249557962210446, 
-0.65814805058186, 1.45934552731308, -0.198212915185781, -0.211035604061083, 
0.338319515486345, 0.131425825921336, -0.0109756149662412, -0.261005180047431, 
0.421454700604434, 0.0983335709673666, -0.101088636265742, 0.0964345427048702, 
-0.508081777689168, -0.400976117619857, -0.117361962694928, 0.813395320337906, 
-0.0346413373967359, -0.481499073877732, 0.0446407662194761, 
-0.0294320910783511, 0.909611617667275, 0.351430930829348, 0.116299410617839, 
-0.210216716796393, -0.398341759098735, -0.65814805058186, -0.65814805058186, 
0.100563003387009, -0.509042421411389, 0.511028651163257, 0.0431293924199034, 
0.660883904007974, -0.65814805058186, -0.378658861813212, -0.273829829019315, 
4.20892674040053, -0.0768753037087408, -0.441185967417642, 2.04802631452868, 
-0.65814805058186, -0.269478192422361, 0.383685850363252, -0.65814805058186, 
-0.65814805058186), AGE18_34FLD = c(-0.130080127803667, -1.09592168708659, 
1.38767089392664, 0.501488548787159, 2.32784926704544, -0.235270792676064, 
1.13931163582532, 0.540327542757423, 0.577515025253877, -0.242080854323515, 
-0.236216562889702, 0.683441681810391, -1.09592168708659, -1.09592168708659, 
0.0363145076530911, -0.0665379199561058, 0.777926607570102, -0.618161892876412, 
-0.250157727065868, 1.61736731760417, -1.09592168708659, 0.864721549837708, 
-1.09592168708659, -0.379828631691471, 0.475617121238168, 0.122900767471024, 
-0.439772550360837, -0.0168435311980823, 1.88907511940198, -0.533093152540427, 
0.358144484675311, 1.00148223521681, -0.367324096381894, -0.561678980799494, 
-0.519893465065387, -0.0699129487007959, 0.470878184230816, 0.0243623172368012, 
0.0588113948605217, 0.937318028328768, 0.13609667987273, -0.268495285162563, 
-0.315540975845525, -1.09592168708659, 1.24814322083601, 2.67435138770459, 
-1.09592168708659, -1.09592168708659, -0.342153857605648, -0.165266920669736, 
0.0137686150682584, -1.09592168708659, 0.246351394326684, -0.488137923633972, 
1.49335568460806, 0.119350799156778, -1.09592168708659, 0.936108606469693, 
-0.213177920633368, -0.523832719613707, 0.577515025253877, -0.843352950034398, 
1.00012940708533, -0.486139365405593, -0.292998940733529, 0.743150579831344, 
1.1246254773305, -0.722886464554487, -0.455879420465601, -0.849129364367611, 
-0.576962043292781, 0.0337991259374796, -0.198409262972822, -0.756883046558241, 
2.276197550065, -0.29146753488179, -1.09592168708659, -0.127089906257899, 
0.440004947086572, 0.336920186574891, -0.205491840718313, 0.570386434359988, 
-1.09592168708659, 0.0690207613365549, 0.478042335448546, -0.0759325566835026, 
-0.422948213521715, -0.39270221470981, -0.0205517035550875, -0.626556215380212, 
-0.542711318422477, -1.09592168708659, 0.577515025253877, -0.67951095494066, 
0.202554102156843, -0.582917317893694, -0.349955858224452, -0.0961368142186437, 
-0.794735291314053, 0.39423386152135), AGE18_34IT = c(-0.0922984151251803, 
-1.10842633780253, 0.18885009667243, 0.507456373441544, 0.420167034168932, 
-0.215706837795418, -1.55427067647275, 0.162389060267949, 0.754154056222705, 
-0.707142977760336, -0.503207778521687, -0.0327393224943708, 
-1.04820335523705, -0.869591359506808, -0.58143266150616, -0.428936764777582, 
0.330878305136329, -1.19500150948863, -1.29439826827064, 0.60163461688395, 
1.16009041015508, 0.175009713455721, -1.36162571986824, -0.27694286108837, 
1.14189589937619, 2.21684247777924, -0.56335821198866, -0.594487825853325, 
1.54764337967806, -0.786067764381153, -0.234733270862691, 2.52999561314261, 
0.331551786013243, -0.0645795477648524, -0.785164273078549, -0.318139022962605, 
-0.125442459074011, -0.0312375211548161, 0.292594934974131, 0.726270176261751, 
-1.00038597055337, -0.164525780028025, -0.293519419828208, 1.40909260663938, 
0.0624912149656389, 2.63450341616489, 0.0947775908237999, 1.49993461399998, 
-0.497483035068795, -0.462494708354598, 0.200813863131902, -0.289102373383507, 
0.136007989951257, -0.873494025109071, 0.870784784862377, 0.553524282848747, 
0.963722824309746, -0.220093193215003, 0.63773601395719, -0.11748356517652, 
-0.0577186516420515, -0.859500286301709, 1.15609583262114, -0.921362952472097, 
-0.383220339847388, 0.525567398769661, 1.35644767444532, -0.823292431599732, 
-0.767028966908216, -1.22221388525185, 0.0545702104407355, -0.142032267392627, 
-0.637930804177903, -1.33111856407521, 6.33776194521154, -0.147949920031134, 
0.664639885868386, -0.963597673049043, -0.142153756039259, -0.0490447979255479, 
0.00402038984987186, 2.24042581854381, 0.3278448665499, -0.378941686111017, 
1.01725476517897, -0.097430200216662, -0.536694449902048, 0.518899495690568, 
-0.721259518837359, -0.390435116407882, -0.933093169027874, -0.297914490917462, 
-0.165968346024019, -0.646177818606699, 0.879472167146227, 1.30791335346685, 
-0.79366174670134, -1.38127002672125, -0.969297767132932, 0.976065929705738
)), row.names = c(6737L, 3053L, 831L, 2255L, 6090L, 5183L, 347L, 
3260L, 2795L, 4098L, 2961L, 4487L, 576L, 1838L, 3515L, 6756L, 
3888L, 5386L, 7080L, 145L, 1236L, 1962L, 1096L, 7603L, 6386L, 
7120L, 2560L, 5374L, 3771L, 13L, 3489L, 6914L, 6893L, 5378L, 
6236L, 1912L, 1734L, 6587L, 2806L, 5165L, 3419L, 7584L, 5958L, 
7661L, 5073L, 5789L, 828L, 2947L, 6510L, 2500L, 274L, 1024L, 
5486L, 4215L, 7079L, 7258L, 2931L, 4856L, 2683L, 6654L, 6953L, 
1424L, 6876L, 6027L, 7459L, 3952L, 6722L, 6039L, 6223L, 3723L, 
6206L, 5029L, 3131L, 3807L, 7124L, 3610L, 960L, 466L, 4465L, 
5901L, 6073L, 6863L, 2636L, 4187L, 5715L, 4266L, 7746L, 4024L, 
3481L, 6300L, 7738L, 1006L, 3714L, 1952L, 3997L, 6171L, 5086L, 
2553L, 4783L, 7212L), class = "data.frame")

# Class labels for the 100 sample rows above, stored as a factor: the integer
# codes 1L..4L index the levels "Q1".."Q4". The structure() form looks like
# dput() output, so it can be pasted back verbatim.
training_set$Target<-structure(c(2L, 3L, 1L, 4L, 4L, 3L, 1L, 1L, 3L, 3L, 3L, 4L, 3L, 
4L, 4L, 2L, 4L, 3L, 3L, 2L, 3L, 3L, 4L, 1L, 2L, 4L, 2L, 3L, 1L, 
3L, 2L, 4L, 1L, 1L, 4L, 1L, 1L, 4L, 2L, 3L, 3L, 2L, 4L, 2L, 2L, 
3L, 2L, 1L, 1L, 4L, 2L, 3L, 4L, 3L, 4L, 2L, 2L, 3L, 4L, 1L, 2L, 
1L, 3L, 3L, 4L, 3L, 3L, 3L, 2L, 1L, 3L, 2L, 2L, 3L, 4L, 4L, 1L, 
2L, 2L, 2L, 2L, 2L, 2L, 2L, 3L, 2L, 1L, 2L, 2L, 2L, 2L, 3L, 2L, 
3L, 1L, 3L, 4L, 4L, 2L, 4L), levels = c("Q1", "Q2", "Q3", "Q4"
), class = "factor")
r caret rpart
1个回答
0
投票

这是一个复杂的问题。最简单的做法可能是把 `rpart` 对象转换为 `igraph` 对象,这样就可以选取以特定类别为终点的子图。

第 1 步:重现问题

您的示例数据不包含 `Target` 列,因此不足以重现您的问题。不过,我们可以用下面的代码轻松构造类似的数据:

# Synthetic reproducible example: every combination of four predictors A-D,
# each taking the values 1..4 (4^4 = 256 rows), plus a four-level class label
# obtained by binning the row sum into (0, 8], (8, 10], (10, 12] and (12, 20].
training_set <- expand.grid(setNames(rep(list(1:4), 4L), c("A", "B", "C", "D")))
training_set[["Target"]] <- cut(
  x = rowSums(training_set),
  breaks = c(0, 8, 10, 12, 20),
  labels = c("w", "x", "y", "z")
)

现在可以用您自己的代码创建 `classifier` 对象:

library(rpart)
library(caret)

# Tune the rpart complexity parameter (cp) with 10 x 10 repeated
# cross-validation via caret, then refit a single rpart tree on the whole
# training set using the selected cp.
#
# Assumes `training_set` is a data.frame whose factor column `Target` holds the
# class labels and whose remaining columns are the predictors.

# Use one split criterion for BOTH the tuning run and the final fit. Tuning cp
# under one impurity measure and refitting under another means the selected cp
# was never cross-validated for the tree that is actually plotted.
split_criterion <- "information"

# Repeated CV draws random folds; fix the seed so the selected cp (and hence
# the plotted tree) is reproducible from run to run.
set.seed(123)

fitControl <- trainControl(method = "repeatedcv",
                           number = 10,
                           repeats = 10)

classifier <- train(
  # drop = FALSE keeps a data.frame even if only one predictor remains.
  x = training_set[, names(training_set) != "Target", drop = FALSE],
  y = training_set$Target,
  method = "rpart",
  parms = list(split = split_criterion),
  trControl = fitControl,
  tuneLength = 20
)
classifier

# `bestTune` is a one-row data.frame; pass the numeric cp value itself rather
# than the whole data.frame to rpart.control().
complexity_parameter <- classifier$bestTune$cp

classifier <- rpart(formula = Target ~ .,
                    data = training_set,
                    parms = list(split = split_criterion),
                    control = rpart.control(cp = complexity_parameter))

现在,`fancyRpartPlot` 的输出同样存在可读性问题:

library(RColorBrewer)
library(rattle)
# Draw the fitted rpart tree `classifier` with rattle's fancyRpartPlot(), using
# the same display settings as in the question.
# NOTE(review): type/branch/tweak/clip.right.labs are presumably forwarded to
# rpart.plot's prp() — confirm in the rattle docs.
fancyRpartPlot(
  classifier,
  caption = NULL,
  type = 3,
  branch = 0.3,
  tweak = 1.4,
  clip.right.labs = FALSE
)

第 2 步:将 `rpart` 转换为 `igraph`

我没有找到将 `rpart` 直接转换为 `igraph` 二叉树的现成方法。`data.tree` 包可以实现 `rpart` → `Node` → `igraph` 的转换,但最终得到的不是二叉树。

这里的做法是利用 rpart 的节点编号规则:编号为 k 的节点,其子节点编号为 2k 和 2k + 1。据此在 `igraph` 中构建一棵二叉树,只保留 `rpart` 对象中实际存在的节点,并把各节点的属性复制过去:

library(igraph)

# Convert the fitted rpart tree `classifier` into a directed igraph tree whose
# edges point from parent to child, so per-class subtrees can later be
# extracted with graph operations.
#
# rpart numbers its nodes so that node k has children 2k and 2k + 1. Those
# numbers are the row names of `classifier$frame`, and each node's parent is
# recovered as k %/% 2.
#
# Vertices are ordered by increasing node number. Vertex attributes:
#   name   - predicted class label for leaves; split label from labels() for
#            internal nodes
#   number - frame$n for the node
#   class  - frame$yval (predicted class index) as a string for leaves,
#            NA for internal nodes
tree_frame <- classifier$frame
node_order <- order(as.integer(row.names(tree_frame)))
tree_frame <- tree_frame[node_order, ]
node_ids <- as.integer(row.names(tree_frame))
is_leaf <- tree_frame$var == "<leaf>"

# Only the nodes rpart actually created become vertices. Building a complete
# binary tree of max(node_ids) vertices and then deleting the gaps would
# allocate about 2^depth vertices, which explodes for deep trees.
child_idx <- which(node_ids != 1L)
parent_idx <- match(node_ids[child_idx] %/% 2L, node_ids)
g <- make_graph(as.vector(rbind(parent_idx, child_idx)),
                n = length(node_ids), directed = TRUE)

# Class labels come from the model itself (attr "ylevels"), not from
# training_set, so the conversion depends only on `classifier`.
vertex_attr(g, "name") <- ifelse(is_leaf,
                                 attr(classifier, "ylevels")[tree_frame$yval],
                                 labels(classifier)[node_order])
vertex_attr(g, "number") <- tree_frame$n
vertex_attr(g, "class") <- ifelse(is_leaf,
                                  as.character(tree_frame$yval),
                                  NA_character_)

第 3 步:用 `ggraph` 绘图,检查 `igraph` 对象是否正确

我们用 `ggraph` 绘制这个 `igraph` 对象,以检查它是否正确。请注意,由于还没有拆分成子图,可读性问题依然存在:

library(ggraph)

# Plot the whole converted tree to check the conversion. Each node shows its
# label and observation count; nodes are filled by their `class` attribute,
# so leaves get a class colour and internal nodes (class NA) stay white.
ggraph(g, layout = "tree") +
  geom_edge_diagonal() +
  geom_node_label(
    aes(label = paste0(name, "\n n = ", number), fill = class)
  ) +
  scale_fill_manual(
    values = c("1" = "lightgreen", "2" = "lightblue",
               "3" = "orange", "4" = "#E0A8FF"),
    na.value = "white",
    guide = "none"
  ) +
  theme_graph()

第 4 步:获取以每个类别为终点的子图

具体做法分两步。首先,对每个类别找出所有预测为该类别的叶节点,并取出从根节点到这些叶节点路径上的全部顶点(即各叶节点的所有祖先)。然后,由这些顶点生成导出子图(induced subgraph):

# Build one subgraph per class level. For level k, take every leaf predicted as
# class k together with all of its ancestors (subcomponent() with mode "in"
# follows edges back toward the root), then take the induced subgraph on those
# vertices.
#
# Leaves are matched on the `class` attribute (the class index) rather than on
# `name`. Internal nodes are named by their split label, and such a label
# (e.g. "root") could coincide with a class name.
#
# A class that no leaf predicts yields an empty (0-vertex) graph, so `subs`
# always has one element per class level, in level order. Elements are named
# by class level; positional indexing such as subs[[1]] still works.
class_levels <- levels(training_set$Target)

subs <- lapply(seq_along(class_levels), function(k) {
  leaves <- which(V(g)$class == as.character(k))
  keep <- leaves |>
    lapply(function(v) subcomponent(g, v, mode = "in")) |>
    unlist() |>
    unique() |>
    as.integer()  # as.integer(NULL) is integer(0) when there are no leaves
  induced_subgraph(g, keep)
})
names(subs) <- class_levels

第 5 步:绘制结果

这一步只是对每个子图重复上面的绘图代码。为方便起见,我们把这些图放在一个列表中:

# One ggraph plot per class subtree, styled like the full-tree plot: each node
# shows its label and observation count, leaves are filled by class index and
# internal nodes (class NA) are white. The list keeps the order of `subs`.
plots <- lapply(subs, \(sub_graph) {
  ggraph(sub_graph, layout = "tree") +
    geom_edge_diagonal() +
    geom_node_label(
      aes(label = paste0(name, "\n n = ", number), fill = class)
    ) +
    scale_fill_manual(
      values = c("1" = "lightgreen", "2" = "lightblue",
                 "3" = "orange", "4" = "#E0A8FF"),
      na.value = "white",
      guide = "none"
    ) +
    theme_graph()
})

现在我们有:

# Display the four per-class subtrees one after another, in class-level order.
print(plots[[1]])
print(plots[[2]])
print(plots[[3]])
print(plots[[4]])

可以看到,这样得到的每棵子树都带有正确的划分路径。每个节点显示的信息都可以自定义——只需在第 2 步中把所需的属性从 `rpart` 复制到 `igraph` 即可。

© www.soinside.com 2019 - 2024. All rights reserved.