MNIST数据集提升

问题描述 投票:0回答:1

我试图将梯度提升应用于MNIST数据集。这是我的代码。

library(dplyr)
library(caret)
mnist <- snedata::download_mnist()
mnist_num <- as.data.frame(lapply(mnist[1:10000,], as.numeric)) %>%
mutate(id = row_number())
mnist_num <- mnist_num[,sapply(mnist_num, function(x){max(x) - min(x) > 0})]

mnist_train <- sample_frac(mnist_num, .70)
mnist_test <- anti_join(mnist_num, mnist_train, by = 'id')

set.seed(5000)
library(gbm)
boost_mnist<-gbm(Label~ .,data=mnist_train, distribution="bernoulli", n.trees=70, 
interaction.depth=4, shrinkage=0.3)

它显示以下错误:

"Error in gbm.fit(x = x, y = y, offset = offset, distribution = distribution, : Bernoulli requires the response to be in {0,1}"

这里有什么问题?谁能告诉我正确的代码?

classification mnist boosting
1个回答
0
投票

错误

Error in gbm.fit(x = x, y = y, offset = offset, distribution = distribution, : Bernoulli requires the response to be in {0,1}

是由于选择的分布,你应该选择的是 multinomial 而非 bernoulli因为伯努利分布只适用于二分反应和 mnist 标签从1到10。

© www.soinside.com 2019 - 2024. All rights reserved.