使用pytorch和sklearn对MNIST数据集进行交叉验证

问题描述 投票:0回答:1

我是pytorch的新手,正在尝试实现前馈神经网络来对mnist数据集进行分类。尝试使用交叉验证时遇到一些问题。我的数据具有以下形状:x_train:torch.Size([45000, 784])和y_train:torch.Size([45000])

我尝试使用sklearn的KFold。

kfold =KFold(n_splits=10)

这是我的火车方法的第一部分,其中我将数据分为折叠部分:

for  train_index, test_index in kfold.split(x_train, y_train): 
        x_train_fold = x_train[train_index]
        x_test_fold = x_test[test_index]
        y_train_fold = y_train[train_index]
        y_test_fold = y_test[test_index]
        print(x_train_fold.shape)
        for epoch in range(epochs):
         ...

y_fold的索引是正确的,它很简单:[ 0 1 2 ... 4497 4498 4499],但不是用于x_fold,即[ 4500 4501 4502 ... 44997 44998 44999]

我希望将变量x_fold作为前4500张图片,换句话说,其形状为torch.Size([4500, 784]),但形状为torch.Size([40500, 784])

关于正确处理方法的任何提示吗?

scikit-learn pytorch cross-validation mnist k-fold
1个回答
2
投票

您弄乱了发票。

x_train = x[train_index]
x_test = x[test_index]
y_train = y[train_index]
y_test = y[test_index]
    x_fold = x_train[train_index]
    y_fold = y_train[test_index]

应该是:

x_fold = x_train[train_index]
y_fold = y_train[train_index]
© www.soinside.com 2019 - 2024. All rights reserved.