无法使用geom_contour()绘制R中的决策边界

问题描述 投票:0回答:1

我知道对此有几个类似的问题,但我仍在努力理解其背后的概念。

我根据knn预测了类(“ true”和“ false”):

hype_knn_prediction = knn(hype_quant_train[,2:3], hype_quant_test[,2:3], cl = hype_quant_train$Selected, k = k)

而且我也能够绘制测试数据:

ggplot(hype_quant_test, aes(x= Comments, y= Votings, color = Selected, shape = Selected)) + 
  geom_point(size = 3) +
  labs(y="Votes", x = "Comments")+
  ggtitle("Testdata")+
  theme(plot.title = element_text(hjust = 0.5))+
  theme(legend.position="bottom")

Plot of test data

现在,我想将类的决策边界添加到测试数据的绘图中。因此,我将hype_knn_prediction作为列添加到hype_quant_test数据帧中。但是,当我在绘图中添加geom_contour(data=hype_quant_test, aes(x=Comments, y=Votings, z= as.numberic(Selected)), breaks=c(0,.5))时,会收到以下消息:在stat_contour()中计算失败:x坐标数必须与密度矩阵中的列数匹配。

我该如何解决此问题?我认为我必须转换一些数据,但不知道如何

编辑

培训数据:

   Selected  Votings          Comments
1      true  0.2348563517    0.162454874
2     false  0.0027691243    0.001805054
3     false  0.0136725511    0.027075812
4     false  0.1128418138    0.077617329
5     false  0.0529595016    0.016245487
6     false  0.0190377293    0.012635379
7     false  0.0231914157    0.001805054
8     false  0.3367947387    0.019855596
9     false  0.0036344756    0.005415162
10    false  0.0051921080    0.005415162
11    false  0.0202492212    0.014440433
12    false  0.0178262375    0.007220217
13    false  0.0029421945    0.010830325
14    false  0.0680166147    0.036101083
15    false  0.0053651783    0.003610108
16    false  0.2397023191    0.034296029
17    false  0.0001730703    0.000000000
18    false  0.0228452752    0.023465704
19    false  0.0129802700    0.000000000
20    false  0.0192107996    0.018050542
21    false  0.0010384216    0.000000000
22    false  0.0129802700    0.005415162
23    false  0.0000000000    0.000000000
24    false  0.0134994808    0.003610108
25    false  0.0742471443    0.039711191
26    false  0.0256143994    0.009025271
27    false  0.0039806161    0.001805054
28     true  0.4110418830    0.050541516
29    false  0.0114226376    0.063176895
30    false  0.0185185185    0.016245487
31    false  0.0051921080    0.003610108
32    false  0.1952232606    0.021660650
33    false  0.1138802354    0.012635379
34    false  0.0048459675    0.016245487
35    false  0.0242298373    0.009025271
36    false  0.0167878159    0.001805054
37    false  0.0039806161    0.001805054
38     true  0.7727587400    0.146209386
39    false  0.0154032537    0.000000000
40    false  0.0057113188    0.007220217
41    false  0.0038075459    0.000000000
42    false  0.0046728972    0.001805054
43    false  0.0152301835    0.003610108
44    false  0.0408445829    0.025270758
45    false  0.0131533403    0.007220217
46    false  0.0578054690    0.037906137
47    false  0.0046728972    0.005415162
48    false  0.0001730703    0.001805054
49    false  0.1169955002    0.122743682
50    false  0.0044998269    0.003610108
51    false  0.0000000000    0.000000000
52    false  0.1439944618    0.036101083
53    false  0.0072689512    0.005415162
54    false  0.0064035999    0.009025271
55    false  0.0614399446    0.027075812
56    false  0.0719972309    0.005415162
57     true  0.3418137764    0.018050542
58    false  0.0117687781    0.012635379
59    false  0.0072689512    0.014440433
60     true  0.0313257182    0.018050542
61    false  0.1021114573    0.019855596
62    false  0.0024229837    0.003610108
63    false  0.0072689512    0.000000000
64    false  0.0169608861    0.003610108
65    false  0.0340948425    0.014440433
66     true  0.7069920388    0.332129964
67     true  0.7377985462    0.175090253
68    false  0.0919003115    0.007220217
69    false  0.0065766701    0.001805054
70    false  0.0401523018    0.027075812
71    false  0.0223260644    0.005415162
72    false  0.0635167878    0.018050542
73    false  0.0013845621    0.000000000
74    false  0.0060574593    0.000000000
75     true  0.6102457598    0.909747292
76    false  0.0022499135    0.001805054
77    false  0.0316718588    0.007220217
78    false  0.0019037729    0.000000000
79     true  1.0000000000    1.000000000
80    false  0.0240567670    0.016245487

添加预测列后的测试数据:

   Selected   Votings          Comments   Prediction
1     false   0.329525787    0.023465704      false
2     false   0.299930772    0.075812274      false
3      true   0.962443752    0.178700361       true
4     false   0.032191070    0.001805054      false
5     false   0.036863967    0.025270758      false
6     false   0.014884043    0.005415162      false
7     false   0.034787124    0.005415162      false
8     false   0.007615092    0.000000000      false
9     false   0.005538249    0.000000000      false
10    false   0.006403600    0.005415162      false
11    false   0.006749740    0.005415162      false
12    false   0.048286604    0.072202166      false
13    false   0.057286258    0.021660650      false
14    false   0.067324334    0.012635379      false
15    false   0.004153686    0.001805054      false
16    false   0.004845967    0.003610108      false
17    false   0.089131187    0.055956679      false
18    false   0.010384216    0.001805054      false
19    false   0.040671513    0.021660650      false
20    false   0.001903773    0.001805054      false

编辑

我尝试使用默认的虹膜数据测试绘图,但消息仍然相同:

library(datasets)
library(class)
library(ggplot2)
library(caret)

iris_df = as.data.frame(iris)

normalize = function(x) {
  return ((x - min(x)) / (max(x) - min(x)))
}

iris_df$Sepal.Length = normalize(iris_df$Sepal.Length)
iris_df$Sepal.Width = normalize(iris_df$Sepal.Width)
iris_df$Petal.Length = normalize(iris_df$Petal.Length)
iris_df$Petal.Width = normalize(iris_df$Petal.Width)

set.seed(1234)

#sampling
sample = sample(nrow(iris_df), nrow(iris_df)*0.8, replace = FALSE)

#Training data
iris_train=iris_df[sample,]

#Testdata
iris_test=iris_df[-sample,]


ggplot(iris_train, aes(x= Sepal.Length, y= Sepal.Width, color = Species, shape = Species)) + 
  geom_point(size = 3) +
  theme(legend.position="bottom")

ggplot(iris_test, aes(x= Sepal.Length, y= Sepal.Width, color = Species, shape = Species)) + 
  geom_point(size = 3) +
  theme(legend.position="bottom")

k = round(sqrt(nrow(iris_train)))


knn_predict = knn(iris_train[,1:4], iris_test[,1:4], cl = iris_train$Species, k = k)

iris_test$Prediction = knn_predict

confusionMatrix(iris_test$Prediction, iris_test$Species)

ggplot(iris_test, aes(x= Sepal.Length, y= Sepal.Width, color = Species, shape = Species)) + 
  geom_point(size = 3) +
  geom_contour(data = iris_test, aes(x= Sepal.Length, y= Sepal.Width, z = as.numeric(Prediction)),breaks = c(0,.5))

stat_contour()中的计算失败:x坐标数必须与密度矩阵中的列数匹配。

r ggplot2 plot knn
1个回答
2
投票

[我认为geom_contour方法的一个关键部分(在虹膜代码中没有发生)是对变量矩阵进行预测。这是您的制作方法。我对预处理没有任何幻想。

library(class)
library(ggplot2)

train <- sample(150, 75)

train_dat <- iris[train, -5]
test_dat <- iris[-train, -5]

vars <- c("Sepal.Width", "Sepal.Length")

# First make a grid
n <- 40
pred.mat <- expand.grid(
  Sepal.Width = with(iris, seq(min(Sepal.Width), max(Sepal.Width), length.out = n)),
  Sepal.Length = with(iris, seq(min(Sepal.Length), max(Sepal.Length), length.out = n))
)

# Then ask for prediction on the grid
pred.mat$pred <- knn(train_dat[, vars], pred.mat, cl = iris$Species[train], k = 3)

# Use grid as input for geom_contour
ggplot(pred.mat, aes(Sepal.Width, Sepal.Length)) +
  geom_point(data = iris, aes(color = Species)) +
  geom_contour(aes(z = as.numeric(pred == "setosa"), 
                   colour = "setosa"), 
               breaks = 0.5) +
  geom_contour(aes(z = as.numeric(pred == "virginica"), 
                   colour = "virginica"), 
               breaks = 0.5) +
  geom_contour(aes(z = as.numeric(pred == "versicolor"), 
                   color = "versicolor"), 
               breaks = 0.5)

“”

reprex package(v0.3.0)在2020-04-18创建

© www.soinside.com 2019 - 2024. All rights reserved.