如何根据标签分割pytorch数据集?

问题描述 投票:0回答:2

我需要按类别拆分 CIFAR10 数据集,以便我可以为每个类别创建具有相同数量样本的较小样本。

我怎样才能最好地实现这一点?

pytorch dataset
2个回答
0
投票
import numpy
sorted_by_value = [0]*10
for i in range(10):
  sorted_by_value[i] =(train.data[numpy.where(numpy.array(train.targets) == i)])
  numpy.random.shuffle(sorted_by_value[i])

对于任何数据集,你可以将 10 替换为类别数,然后就可以了。


0
投票

您可以使用 torch.utils.data.Subset 来实现此目的。下面是如何制作仅包含数字 0 到 4 的 MNIST 训练集子集的示例:

transform = transforms.Compose([transforms.ToTensor()])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)

classes = [0, 1, 2, 3, 4]
indices = [i for i, (k, v) in enumerate(trainset) if v in classes]
trainset = torch.utils.data.Subset(trainset, indices)
© www.soinside.com 2019 - 2024. All rights reserved.