我知道对此有几个类似的问题,但我仍在努力理解其背后的概念。
我根据knn预测了类(“ true”和“ false”):
hype_knn_prediction = knn(hype_quant_train[,2:3], hype_quant_test[,2:3], cl = hype_quant_train$Selected, k = k)
而且我也能够绘制测试数据:
ggplot(hype_quant_test, aes(x= Comments, y= Votings, color = Selected, shape = Selected)) +
geom_point(size = 3) +
labs(y="Votes", x = "Comments")+
ggtitle("Testdata")+
theme(plot.title = element_text(hjust = 0.5))+
theme(legend.position="bottom")
现在,我想将类的决策边界添加到测试数据的绘图中。因此,我将hype_knn_prediction
作为列添加到hype_quant_test
数据帧中。但是,当我在绘图中添加geom_contour(data=hype_quant_test, aes(x=Comments, y=Votings, z= as.numberic(Selected)), breaks=c(0,.5))
时,会收到以下消息:在stat_contour()中计算失败:x坐标数必须与密度矩阵中的列数匹配。
我该如何解决此问题?我认为我必须转换一些数据,但不知道如何
编辑
培训数据:
Selected Votings Comments
1 true 0.2348563517 0.162454874
2 false 0.0027691243 0.001805054
3 false 0.0136725511 0.027075812
4 false 0.1128418138 0.077617329
5 false 0.0529595016 0.016245487
6 false 0.0190377293 0.012635379
7 false 0.0231914157 0.001805054
8 false 0.3367947387 0.019855596
9 false 0.0036344756 0.005415162
10 false 0.0051921080 0.005415162
11 false 0.0202492212 0.014440433
12 false 0.0178262375 0.007220217
13 false 0.0029421945 0.010830325
14 false 0.0680166147 0.036101083
15 false 0.0053651783 0.003610108
16 false 0.2397023191 0.034296029
17 false 0.0001730703 0.000000000
18 false 0.0228452752 0.023465704
19 false 0.0129802700 0.000000000
20 false 0.0192107996 0.018050542
21 false 0.0010384216 0.000000000
22 false 0.0129802700 0.005415162
23 false 0.0000000000 0.000000000
24 false 0.0134994808 0.003610108
25 false 0.0742471443 0.039711191
26 false 0.0256143994 0.009025271
27 false 0.0039806161 0.001805054
28 true 0.4110418830 0.050541516
29 false 0.0114226376 0.063176895
30 false 0.0185185185 0.016245487
31 false 0.0051921080 0.003610108
32 false 0.1952232606 0.021660650
33 false 0.1138802354 0.012635379
34 false 0.0048459675 0.016245487
35 false 0.0242298373 0.009025271
36 false 0.0167878159 0.001805054
37 false 0.0039806161 0.001805054
38 true 0.7727587400 0.146209386
39 false 0.0154032537 0.000000000
40 false 0.0057113188 0.007220217
41 false 0.0038075459 0.000000000
42 false 0.0046728972 0.001805054
43 false 0.0152301835 0.003610108
44 false 0.0408445829 0.025270758
45 false 0.0131533403 0.007220217
46 false 0.0578054690 0.037906137
47 false 0.0046728972 0.005415162
48 false 0.0001730703 0.001805054
49 false 0.1169955002 0.122743682
50 false 0.0044998269 0.003610108
51 false 0.0000000000 0.000000000
52 false 0.1439944618 0.036101083
53 false 0.0072689512 0.005415162
54 false 0.0064035999 0.009025271
55 false 0.0614399446 0.027075812
56 false 0.0719972309 0.005415162
57 true 0.3418137764 0.018050542
58 false 0.0117687781 0.012635379
59 false 0.0072689512 0.014440433
60 true 0.0313257182 0.018050542
61 false 0.1021114573 0.019855596
62 false 0.0024229837 0.003610108
63 false 0.0072689512 0.000000000
64 false 0.0169608861 0.003610108
65 false 0.0340948425 0.014440433
66 true 0.7069920388 0.332129964
67 true 0.7377985462 0.175090253
68 false 0.0919003115 0.007220217
69 false 0.0065766701 0.001805054
70 false 0.0401523018 0.027075812
71 false 0.0223260644 0.005415162
72 false 0.0635167878 0.018050542
73 false 0.0013845621 0.000000000
74 false 0.0060574593 0.000000000
75 true 0.6102457598 0.909747292
76 false 0.0022499135 0.001805054
77 false 0.0316718588 0.007220217
78 false 0.0019037729 0.000000000
79 true 1.0000000000 1.000000000
80 false 0.0240567670 0.016245487
添加预测列后的测试数据:
Selected Votings Comments Prediction
1 false 0.329525787 0.023465704 false
2 false 0.299930772 0.075812274 false
3 true 0.962443752 0.178700361 true
4 false 0.032191070 0.001805054 false
5 false 0.036863967 0.025270758 false
6 false 0.014884043 0.005415162 false
7 false 0.034787124 0.005415162 false
8 false 0.007615092 0.000000000 false
9 false 0.005538249 0.000000000 false
10 false 0.006403600 0.005415162 false
11 false 0.006749740 0.005415162 false
12 false 0.048286604 0.072202166 false
13 false 0.057286258 0.021660650 false
14 false 0.067324334 0.012635379 false
15 false 0.004153686 0.001805054 false
16 false 0.004845967 0.003610108 false
17 false 0.089131187 0.055956679 false
18 false 0.010384216 0.001805054 false
19 false 0.040671513 0.021660650 false
20 false 0.001903773 0.001805054 false
编辑
我尝试使用默认的虹膜数据测试绘图,但消息仍然相同:
library(datasets)
library(class)
library(ggplot2)
library(caret)
iris_df = as.data.frame(iris)
normalize = function(x) {
return ((x - min(x)) / (max(x) - min(x)))
}
iris_df$Sepal.Length = normalize(iris_df$Sepal.Length)
iris_df$Sepal.Width = normalize(iris_df$Sepal.Width)
iris_df$Petal.Length = normalize(iris_df$Petal.Length)
iris_df$Petal.Width = normalize(iris_df$Petal.Width)
set.seed(1234)
#sampling
sample = sample(nrow(iris_df), nrow(iris_df)*0.8, replace = FALSE)
#Training data
iris_train=iris_df[sample,]
#Testdata
iris_test=iris_df[-sample,]
ggplot(iris_train, aes(x= Sepal.Length, y= Sepal.Width, color = Species, shape = Species)) +
geom_point(size = 3) +
theme(legend.position="bottom")
ggplot(iris_test, aes(x= Sepal.Length, y= Sepal.Width, color = Species, shape = Species)) +
geom_point(size = 3) +
theme(legend.position="bottom")
k = round(sqrt(nrow(iris_train)))
knn_predict = knn(iris_train[,1:4], iris_test[,1:4], cl = iris_train$Species, k = k)
iris_test$Prediction = knn_predict
confusionMatrix(iris_test$Prediction, iris_test$Species)
ggplot(iris_test, aes(x= Sepal.Length, y= Sepal.Width, color = Species, shape = Species)) +
geom_point(size = 3) +
geom_contour(data = iris_test, aes(x= Sepal.Length, y= Sepal.Width, z = as.numeric(Prediction)),breaks = c(0,.5))
stat_contour()
中的计算失败:x坐标数必须与密度矩阵中的列数匹配。
[我认为geom_contour
方法的一个关键部分(在虹膜代码中没有发生)是对变量矩阵进行预测。这是您的制作方法。我对预处理没有任何幻想。
library(class)
library(ggplot2)
train <- sample(150, 75)
train_dat <- iris[train, -5]
test_dat <- iris[-train, -5]
vars <- c("Sepal.Width", "Sepal.Length")
# First make a grid
n <- 40
pred.mat <- expand.grid(
Sepal.Width = with(iris, seq(min(Sepal.Width), max(Sepal.Width), length.out = n)),
Sepal.Length = with(iris, seq(min(Sepal.Length), max(Sepal.Length), length.out = n))
)
# Then ask for prediction on the grid
pred.mat$pred <- knn(train_dat[, vars], pred.mat, cl = iris$Species[train], k = 3)
# Use grid as input for geom_contour
ggplot(pred.mat, aes(Sepal.Width, Sepal.Length)) +
geom_point(data = iris, aes(color = Species)) +
geom_contour(aes(z = as.numeric(pred == "setosa"),
colour = "setosa"),
breaks = 0.5) +
geom_contour(aes(z = as.numeric(pred == "virginica"),
colour = "virginica"),
breaks = 0.5) +
geom_contour(aes(z = as.numeric(pred == "versicolor"),
color = "versicolor"),
breaks = 0.5)
由reprex package(v0.3.0)在2020-04-18创建