是否有一种方法可以从CIFAR-10训练数据集中仅提取所需的类?

问题描述 投票:0回答:1

我想做的事情看起来很简单,但是没有用。我想对每类图像(矩阵)执行某些操作,因此我首先必须从加扰的批次中提取每个图像。

from tensorflow.keras import datasets
import numpy as np

(train_images, train_labels), (test_images, test_labels)= datasets.cifar10.load_data()
print(len(train_images))
print(len(train_images))
train_images[train_labels==6]

这是错误。当然,这是因为图像矩阵的形状(50000,32,32,3)。即使图像和标签的长度相同,都为50000,但python无法以某种方式使用矩阵作为1项进行过滤。帮助将非常受欢迎。

50000
50000


---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-170-029cc3d4f0a9> in <module>
      5 
      6 
----> 7 train_images[train_labels==6]

IndexError: boolean index did not match indexed array along dimension 1; dimension is 32 but corresponding boolean dimension is 1

我想做的事情看起来很简单,但是没有用。我想对每类图像(矩阵)执行某些操作,因此我首先必须从加扰的批次中提取每个图像。来自...

python list tensorflow arraylist conv-neural-network
1个回答
0
投票

这里的问题是train_labels具有形状(50000,1),因此当您对其进行索引时,numpy尝试将其用作二维。这是一个简单的解决方法。

© www.soinside.com 2019 - 2024. All rights reserved.