我想按标签中的数值对数据集进行排序。
有没有来自pytorch的功能来有效地处理这个问题?
我的数据集type()
来自:
<class 'torchvision.datasets.mnist.MNIST'>
没有通用的方法来有效地执行此操作,因为数据集类仅实现了__getitem__
和__len__
方法,并且不一定具有关于标签的任何“存储”信息。
但是,对于MNIST dataset类,您可以从标签列表中对数据集进行排序。
例如,当您要列出标签为5的索引时。
mnist = torchvision.datasets.mnist.MNIST("/")
labels = mnist.train_labels
fives = (labels == 5).nonzero()