如何在pytorch中对数据集进行排序

问题描述 投票:0回答:1

我想按标签中的数值对数据集进行排序。

有没有来自pytorch的功能来有效地处理这个问题?

我的数据集type()来自:

 <class 'torchvision.datasets.mnist.MNIST'>
python machine-learning pytorch mnist
1个回答
0
投票

没有通用的方法来有效地执行此操作,因为数据集类仅实现了__getitem____len__方法,并且不一定具有关于标签的任何“存储”信息。

但是,对于MNIST dataset类,您可以从标签列表中对数据集进行排序。

例如,当您要列出标签为5的索引时。

mnist = torchvision.datasets.mnist.MNIST("/")
labels = mnist.train_labels
fives = (labels == 5).nonzero()

© www.soinside.com 2019 - 2024. All rights reserved.