从现有的数据集Torchvision创建降低数据集

问题描述 投票:0回答:1

我们都知道普通MNIST数据集,包括在torchvision.datasets包。想象一下,我想创建一个仅包含1和0只进行分类这两个数字,而不是所有10个值数据集的简化版本。

我已经看到了自定义的数据集可以继承所需的数据集类中创建,所以__getitem__,其给定索引处返回的项目。所以,我已经做到了这一点:

class MNIST01(MNIST):
    def __getitem__(self, idx):
        image, label = super().__getitem__(idx)
        if label.item() <= 1:
            return image, label
        else:
            return None

问题是,它似乎是因为它的要求是我不能返回无值“含有张量,数量,类型的字典或列表,发现了类‘NoneType’”。

有一个简单的方法以类似的方式来获得这种数据集的简化版本容易吗?

python dataset torchvision
1个回答
0
投票

我终于成功地应对NoneType问题。保持在这个问题中定义的功能。

class MNIST01(MNIST):
    def __getitem__(self, idx):
        features, target = super(MNIST01, self).__getitem__(idx)
        if target.item() <= 1:
            return features, target

现在,我们需要定义一个自定义collate function collate_fn我们的DataLoader,其处理样品的列表形成一批。在此功能中,我们可以应用过滤器来处理Nonevalues和忽略它们。

from torch.utils.data.dataloader import default_collate

def filter_collate(batch):
    batch = list(filter(lambda x: x is not None, batch))
    return default_collate(batch)

然后,我们只需要这个功能传递给DataLoader

from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, collate_fn=filter_collate, **kwargs)
test_loader = DataLoader(test_dataset, collate_fn=filter_collate, **kwargs)

第2版

访问数据时更容易比第一个,避免一些问题。只是直接过滤train_datatrain_label属性(和相应的测试集)从MNIST类的instanciation。

train_dataset.train_data = train_dataset.train_data[train_dataset.train_labels <= 1]
train_dataset.train_labels = train_dataset.train_labels[train_dataset.train_labels <= 1]
© www.soinside.com 2019 - 2024. All rights reserved.