将自定义 PyTorch 数据集拆分为训练加载器和验证加载器:即使数据集被拆分,两者的长度也相同?

问题描述 投票:0回答:2

我正在尝试将 Pytorch 自定义数据集 (MNIST) 拆分为训练集和验证集,如下所示:

def get_train_valid_splits(data_dir,
                           batch_size,
                           random_seed=1,
                           valid_size=0.2,
                           shuffle=True,
                           num_workers=4,
                           pin_memory=False):

    normalize = transforms.Normalize((0.1307,), (0.3081,))  # MNIST

    # define transforms
    valid_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize
        ])

        train_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize
        ])

    # load the dataset
    train_dataset = datasets.MNIST(root=data_dir, train=True,
                download=True, transform=train_transform)

    valid_dataset = datasets.MNIST(root=data_dir, train=True,
                download=True, transform=valid_transform)

    dataset_size = len(train_dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(valid_size * dataset_size))

    
    if shuffle == True:
        np.random.seed(random_seed)
        np.random.shuffle(indices)
    

    train_idx, valid_idx = indices[split:], indices[:split]

    train_sampler = sampler.SubsetRandomSampler(train_idx)
    valid_sampler = sampler.SubsetRandomSampler(valid_idx)

    print(len(train_sampler))
    print(len(valid_sampler))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                    batch_size=batch_size, sampler=train_sampler,
                    num_workers=num_workers, pin_memory=pin_memory)

    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                    batch_size=batch_size, sampler=valid_sampler,
                    num_workers=num_workers, pin_memory=pin_memory)

    print(len(train_loader.dataset))
    print(len(valid_loader.dataset))

    return (train_loader, valid_loader)

调用该函数后,我注意到采样索引的结果看起来是正确的,48000 和 12000:

print(len(train_sampler))
print(len(valid_sampler))

但是当我查看与 train_loader 和 valid_loader 相关的数据集的长度时:

print(len(train_loader.dataset))
print(len(valid_loader.dataset))

我得到的两者长度相同:60000!知道这是怎么回事吗?为什么它给两个相同的长度,即使我明确地按索引分割它?

python validation pytorch mnist dataloader
2个回答
0
投票

这是因为数据加载器不会修改您传递给它的数据集,而是在您尝试通过迭代访问数据时“应用”批量大小、采样器等内容。您的问题是

len(loader.dataset)
,当您真正想要
len(loader)
时,它会为您提供未修改的所提供数据集的长度,这是“应用”批量大小和采样器等内容后的数据集长度。

import torch
import numpy as np

dataset = np.random.rand(100,200)
sampler = torch.utils.data.SubsetRandomSampler(list(range(70)))

loader = torch.utils.data.DataLoader(dataset, sampler=sampler)
print(len(loader)) 
>>> 70
print(len(loader.dataset))
>>> 100

注意: len 的结果会受到批量大小的影响:

# with batch size
loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=2)
print(len(loader)) 
>>> 35
print(len(loader.dataset))
>>> 100

0
投票

train_loader
valid_loader
的长度相同是因为您对
train_dataset
valid_dataset
使用了相同的数据。

你想要

valid_dataset = datasets.MNIST(root=data_dir, train=False,
                               download=True, transform=valid_transform)

(不是

train=True
)下载验证集。

© www.soinside.com 2019 - 2024. All rights reserved.