mnist数据集的输出如何

问题描述 投票:0回答:1

我使用mnist数据集测试我的网络。因此,模型的输出形状为10。如何重塑输出?例如,如果输出是标签3,则输出[0 0 0 1 0 0 0 0 0 0 0]还是[0 0 0 3 0 0 0 0 0 0]或完全不同?

事情是我不想使用数据加载器。我使用这种方法:

from mlxtend.data import loadlocal_mnist
X, y = loadlocal_mnist(
            images_path='/home/wai043/data/mnist/train-images-idx3-ubyte', 
            labels_path='/home/wai043/data/mnist/train-labels-idx1-ubyte')```

python neural-network pytorch mnist
1个回答
0
投票

如果标签为3,则输出必须为[0、0、0、1、0、0、0、0、0、0]。从loadlocal_mnist获得的y参数具有直接标签,因此您需要训练前先“一次性编码”。

您可以使用以下代码进行编码

from mlxtend.preprocessing import one_hot
from mlxtend.data import loadlocal_mnist

X, y = loadlocal_mnist(images_path='/home/wai043/data/mnist/train-images-idx3-ubyte', 
                       labels_path='/home/wai043/data/mnist/train-labels-idx1-ubyte')
y = one_hot(y)
© www.soinside.com 2019 - 2024. All rights reserved.