如何将 MNIST 的 numpy 数组转换为 pytorch 数据集/数据

问题描述 投票:0回答:1

在我的 ML 课上,我的老师给了我们 MNIST 数据,让我们使用 CNN 对其进行训练。

数据位于 matlab

.mat
文件中

我设法将其变成一个

(60000, 784)

的 numpy 数组

(60000个训练数据,每个数据为28x28=784)

标签(数字

0-9
)也存储在
(60000, 1)
数组中

现在我需要将其加载到

torch.utils.data.DataLoader

但我在互联网上找到的只是 pytorch 本身的数据集

torchvision.datasets.MNIST

我不确定我的数据是否与 pytorch 具有相同的形状和结构

有什么想法吗?谢谢!

python numpy pytorch conv-neural-network mnist
1个回答
0
投票

您可以将

DataLoader
与您的数据一起使用:

from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)

但这还不够。你的目标在哪里?

我使用了

这里
mat文件。完整示例:

import scipy.io
import torch
from torch.utils.data import TensorDataset, DataLoader

mnist = scipy.io.loadmat('mnist_uint8.mat')

# Extract data from mat file and convert numpy array as tensor
X_train = torch.Tensor(mnist['train_x'])
y_train = torch.Tensor(mnist['train_y'])
X_test = torch.Tensor(mnist['test_x'])
y_test = torch.Tensor(mnist['test_y'])

# Create tensor datasets with features (X) and target (y)
train_ds = TensorDataset(X_train, y_train)
test_ds = TensorDataset(X_test, y_test)

# Then make data loaders
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=8, shuffle=True)

输入数据:

>>> mnist
{'__header__': b'MATLAB 5.0 MAT-file, Platform: PCWIN, Created on: Wed Feb 22 20:38:11 2012',
 '__version__': '1.0',
 '__globals__': [],
 'train_x': array([[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
 'train_y': array([[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 1, 0],
        ...,
        [1, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 1, 0],
        [0, 0, 1, ..., 0, 0, 0]], dtype=uint8),
 'test_x': array([[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
 'test_y': array([[0, 0, 1, ..., 0, 0, 0],
        [0, 1, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 1],
        ...,
        [1, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 1],
        [0, 0, 0, ..., 0, 0, 0]], dtype=uint8)}
© www.soinside.com 2019 - 2024. All rights reserved.