Reading the CIFAR-10 dataset in batches

Question (votes: 0, answers: 5)

I am trying to read the CIFAR-10 dataset, which is provided in batches at https://www.cs.toronto.edu/~kriz/cifar.html. I am trying to use pickle to get it into a dataframe and read the "data" part of it, but I get this error:

KeyError                                  Traceback (most recent call last)
<ipython-input-24-8758b7a31925> in <module>()
----> 1 unpickle('datasets/cifar-10-batches-py/test_batch')

<ipython-input-23-04002b89d842> in unpickle(file)
      3     fo = open(file, 'rb')
      4     dict = pickle.load(fo, encoding ='bytes')
----> 5     X = dict['data']
      6     fo.close()
      7     return dict

KeyError: 'data'

I am using IPython, and here is my code:

import pickle

def unpickle(file):
    fo = open(file, 'rb')
    dict = pickle.load(fo, encoding='bytes')
    X = dict['data']
    fo.close()
    return dict

unpickle('datasets/cifar-10-batches-py/test_batch')
Tags: python-3.x, machine-learning, computer-vision, batch-processing
5 Answers

11 votes

You can read the CIFAR-10 dataset with the code given below; just make sure you supply the right directory, i.e. the one where you placed the batches.

import tensorflow as tf
import pandas as pd
import numpy as np
import math
import timeit
import matplotlib.pyplot as plt
from six.moves import cPickle as pickle
import os
import platform
from subprocess import check_output
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

%matplotlib inline


img_rows, img_cols = 32, 32
input_shape = (img_rows, img_cols, 3)
def load_pickle(f):
    version = platform.python_version_tuple()
    if version[0] == '2':
        return  pickle.load(f)
    elif version[0] == '3':
        return  pickle.load(f, encoding='latin1')
    raise ValueError("invalid python version: {}".format(version))

def load_CIFAR_batch(filename):
    """ load single batch of cifar """
    with open(filename, 'rb') as f:
        datadict = load_pickle(f)
        X = datadict['data']
        Y = datadict['labels']
        # NOTE: this keeps each image as a flat 3072-value row
        # (1024 R + 1024 G + 1024 B); it does not reshape to 32x32x3.
        X = X.reshape(10000, 3072)
        Y = np.array(Y)
        return X, Y
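Since load_CIFAR_batch keeps each image flattened, you need one extra step to get image-shaped arrays, e.g. for plt.imshow. A common conversion, shown here as a sketch (my addition, not part of the answer), exploits the channel-plane layout of each row:

# Convert flat rows into (N, 32, 32, 3) RGB images for display.
X, Y = load_CIFAR_batch('../input/cifar-10-batches-py/data_batch_1')
X_img = X.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1).astype('uint8')
plt.imshow(X_img[0])
plt.title(classes[Y[0]])
plt.show()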

def load_CIFAR10(ROOT):
    """ load all of cifar """
    xs = []
    ys = []
    for b in range(1,6):
        f = os.path.join(ROOT, 'data_batch_%d' % (b, ))
        X, Y = load_CIFAR_batch(f)
        xs.append(X)
        ys.append(Y)
    Xtr = np.concatenate(xs)
    Ytr = np.concatenate(ys)
    del X, Y
    Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
    return Xtr, Ytr, Xte, Yte
def get_CIFAR10_data(num_training=49000, num_validation=1000, num_test=10000):
    # Load the raw CIFAR-10 data
    cifar10_dir = '../input/cifar-10-batches-py/'
    X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)

    # Subsample the data
    mask = range(num_training, num_training + num_validation)
    X_val = X_train[mask]
    y_val = y_train[mask]
    mask = range(num_training)
    X_train = X_train[mask]
    y_train = y_train[mask]
    mask = range(num_test)
    X_test = X_test[mask]
    y_test = y_test[mask]

    # Normalize pixel values to [0, 1] for all three splits
    # (the original code left the validation set unscaled).
    x_train = X_train.astype('float32') / 255
    x_val = X_val.astype('float32') / 255
    x_test = X_test.astype('float32') / 255

    return x_train, y_train, x_val, y_val, x_test, y_test


# Invoke the above function to get our data.
x_train, y_train, x_val, y_val, x_test, y_test = get_CIFAR10_data()


print('Train data shape: ', x_train.shape)
print('Train labels shape: ', y_train.shape)
print('Validation data shape: ', x_val.shape)
print('Validation labels shape: ', y_val.shape)
print('Test data shape: ', x_test.shape)
print('Test labels shape: ', y_test.shape)
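With the default split sizes (49,000 train / 1,000 validation / 10,000 test), these prints should show shapes like the following; note that the images stay flattened to 3072 values per row:

Train data shape:  (49000, 3072)
Train labels shape:  (49000,)
Validation data shape:  (1000, 3072)
Validation labels shape:  (1000,)
Test data shape:  (10000, 3072)
Test labels shape:  (10000,)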

2 votes

I found the cause! I had the same problem and solved it. The key issue is the encoding: change the code from

dict = pickle.load(fo, encoding='bytes')

to

dict = pickle.load(fo, encoding='latin1')
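Either encoding can work; the difference is the type of the dictionary keys, which is exactly why the original code raised KeyError: with encoding='bytes' the keys are bytes objects, so the lookup must use b'data'. A minimal sketch of both variants:

import pickle

# Variant 1: keys are bytes objects, so index with b'data'.
with open('datasets/cifar-10-batches-py/test_batch', 'rb') as fo:
    batch = pickle.load(fo, encoding='bytes')
X = batch[b'data']

# Variant 2: keys are plain strings, so index with 'data'.
with open('datasets/cifar-10-batches-py/test_batch', 'rb') as fo:
    batch = pickle.load(fo, encoding='latin1')
X = batch['data']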

0 votes

I ran into a similar problem in the past.

I would like to point future readers to a Python wrapper, found here, for automatically downloading, extracting, and parsing the cifar10 dataset.
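If you only need the data rather than a standalone wrapper, the loader that ships with Keras (my suggestion, not necessarily the wrapper linked above) also downloads, extracts, and parses CIFAR-10 automatically:

import tensorflow as tf

# Downloads and caches the dataset on first call.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
print(x_train.shape)  # (50000, 32, 32, 3)
print(y_train.shape)  # (50000, 1)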


0 votes

This answer builds on Sohaib Anwaar's answer above, but changes how the dataset is obtained: as a TensorFlow Dataset (tf.data.Dataset) rather than as NumPy arrays.

Why a TensorFlow Dataset?

tf.data.Dataset provides easy-to-use, high-performance input pipelines and is the "correct" way to access any dataset in TensorFlow 2.x.

Python version >= 3.10

For Python versions >= 3.10, the solution for getting a TensorFlow Dataset is very simple: use tensorflow_datasets.

import tensorflow_datasets as tfds

(ds_train, ds_test), ds_info = tfds.load(
    "cifar10", 
    split=["train", "test"], 
    as_supervised=True, 
    with_info=True
)
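A typical next step (a sketch with arbitrarily chosen hyperparameters) is to turn the loaded split into an input pipeline:

import tensorflow as tf

# Scale pixels to [0, 1], then shuffle, batch, and prefetch.
ds_train = ds_train.map(lambda x, y: (tf.cast(x, tf.float32) / 255.0, y))
ds_train = ds_train.shuffle(10_000).batch(64).prefetch(tf.data.AUTOTUNE)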

Python version <= 3.9.x

After downloading the CIFAR-10 dataset, extract the tar.gz contents into a folder named data.

Changes from the accepted answer

  1. load_CIFAR10 returns a TensorFlow Dataset instead of NumPy arrays.
  2. load_CIFAR10 splits the dataset into training, cross-validation, and test sets.
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

import os
import math
import platform
import pickle

def load_CIFAR_batch(file_path):
    """
    Load single batch of CIFAR-10 images from
    the binary file and return as a NumPy array.
    """
    with open(file_path, "rb") as f:
        data_dict = pickle.load(f, encoding="latin1")

        # Extract NumPy from dictionary.
        X = data_dict["data"]
        y = data_dict["labels"]

        # Reshape each flat row into a (32, 32, 3) RGB image; the Fortran-order
        # reshape plus transpose is equivalent to the more common
        # X.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1).
        X = X.reshape(BATCH_SIZE, *input_shape, order="F").transpose((0, 2, 1, 3))
        y = np.expand_dims(y, axis=1)

        return X, y

def load_CIFAR10(cv_size=0.25):
    """
    Load all batches of CIFAR-10 images from the
    binary file and return as TensorFlow DataSet.
    """
    X_btchs = []
    y_btchs = []
    for batch in range(1, 6):
        file_path = os.path.join(ROOT, "data_batch_%d" % (batch,))
        X, y = load_CIFAR_batch(file_path)
        X_btchs.append(X)
        y_btchs.append(y)

    # Combine all batches.
    all_Xbs = np.concatenate(X_btchs)
    all_ybs = np.concatenate(y_btchs)

    # Convert Train dataset from NumPy array to TensorFlow Dataset.
    ds_all = tf.data.Dataset.from_tensor_slices((all_Xbs, all_ybs))

    # Split dataset into Train and Cross-validation sets.
    al_size = len(ds_all)
    tr_size = math.ceil((1 - cv_size) * al_size)
    print(f"Train dataset size: {tr_size}.")
    ds_tr = ds_all.take(tr_size)
    print(f"Cross-validation dataset size: {al_size - tr_size}.")
    ds_cv = ds_all.skip(tr_size)

    # Convert Test dataset from NumPy array to TensorFlow Dataset.
    X_ts, y_ts = load_CIFAR_batch(os.path.join(ROOT, "test_batch"))
    ds_ts = tf.data.Dataset.from_tensor_slices((X_ts, y_ts))
    print(f"Test dataset size {len(ds_ts)}.")

    return ds_tr, ds_cv, ds_ts

ROOT = "../data/cifar-10-batches-py/"

BATCH_SIZE = 10000
img_rows, img_cols = 32, 32
input_shape = (img_rows, img_cols, 3)

ds_tr, ds_cv, ds_ts = load_CIFAR10()

Output

Train dataset size: 37500.
Cross-validation dataset size: 12500.
Test dataset size 10000.

Verifying the dataset

xi, yi = next(ds_tr.as_numpy_iterator())

plt.imshow(xi)
plt.title(f"Class label: {yi[0]}")
plt.show()

P.S. The image above is a pixelated image of a frog.
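To actually train on these datasets they still need to be batched. Below is a minimal sketch with a deliberately tiny placeholder model (the model is my assumption, not part of this answer; tf.keras.layers.Rescaling requires TensorFlow >= 2.6):

# Placeholder model; any Keras model taking (32, 32, 3) inputs works here.
model = tf.keras.Sequential([
    tf.keras.layers.Rescaling(1.0 / 255, input_shape=input_shape),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

# tf.data pipelines are batched explicitly before being passed to fit().
model.fit(ds_tr.batch(64).prefetch(tf.data.AUTOTUNE),
          validation_data=ds_cv.batch(64),
          epochs=1)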


-1 votes

Try this:


def unpickle(file):
    # Note: cPickle is Python 2 only; on Python 3 use
    # pickle.load(fo, encoding='bytes') instead.
    import cPickle
    with open(file, 'rb') as fo:
        data = cPickle.load(fo)
    return data
