创建 pyTorch 测试数据集(无标签)

问题描述 投票:0回答:2

我为我的训练数据创建了一个 pyTorch 数据集,其中包含特征和标签,以便能够使用 this 教程来利用 pyTorch DataLoader。这对于我的训练数据效果很好,但在加载测试 csv 文件时出现错误 (

KeyError: "['label'] not found in axis"
),除了没有“标签”列之外,该文件是相同的。

如果有帮助,预期的输入 csv 文件是 csv 文件中的 MNIST 数据,其中具有 28*28 特征列。

import torch

class mnist(torch.utils.data.Dataset):
    
    def __init__(self, csv_file):
        self.train = pd.read_csv(csv_file)
        self.train_x = self.train.drop("label", axis=1)
    
    def __len__(self):
        return len(self.train)
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        if isinstance(idx, list):
            idx_len = len(idx)
        else:
            idx_len = 1
        
        X = np.asarray(self.train_x.iloc[idx], dtype=np.float32)
        X = np.reshape(X, (1,28,28))
        y = np.asarray(self.train.iloc[idx]['label'])
        
        sample = {'X': X, 'y':y}
        
        return torch.from_numpy(sample['X']), torch.from_numpy(sample['y'])
python pytorch pytorch-dataloader
2个回答
0
投票

您应该能够使用这两种数据:

import torch

class mnist(torch.utils.data.Dataset):
    
    def __init__(self, csv_file):
        self.train = pd.read_csv(csv_file)

        self.training = "label" in self.train.columns
        self.train_x = self.train if not self.training else self.train.drop("label", axis=1)
    
    def __len__(self):
        return len(self.train)
    
    def __getitem__(self, idx):
        ...
        
        X = np.asarray(self.train_x.iloc[idx], dtype=np.float32)
        X = np.reshape(X, (1,28,28))
        if not self.training:
            return torch.from_numpy(X])

        y = np.asarray(self.train.iloc[idx]['label'])

        sample = {'X': X, 'y':y}
        return torch.from_numpy(sample['X']), torch.from_numpy(sample['y'])

0
投票

这里我使用了 torchvision.datasets 中的 cifar10 数据集。 &我会忽略它的标签

import torch
class IgnoreLabelDataset(torch.utils.data.Dataset):
    def __init__(self, orig):
        self.orig = orig

    def __getitem__(self, index):
        return self.orig[index][0]

    def __len__(self):
        return len(self.orig)

import torchvision.datasets as dset
import torchvision.transforms as transforms

cifar = dset.CIFAR10(root='data/', download=True,
                         transform=transforms.Compose([
                             transforms.Scale(32),
                             transforms.ToTensor(),
                             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                         ])
)

dataset_with_no_labels = IgnoreLabelDataset(cifar)

最后

dataset_with_no_labels
是我的没有标签的数据集对象

© www.soinside.com 2019 - 2024. All rights reserved.