PyTorch DataLoader随机播放

问题描述 投票:0回答:1

我做了一个实验,但没有得到我期望的结果。

第一部分,我正在使用

trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, 
                                          shuffle=False, num_workers=0)

我在训练模型之前将trainloader.dataset.targets保存到变量a,将trainloader.dataset.data保存到变量b。然后,我使用trainloader训练模型。训练完成后,我将trainloader.dataset.targets保存到变量c,将trainloader.dataset.data保存到变量d。最后,我检查a == cb == d,它们都给出了True,这是预期的,因为DataLoader的洗牌参数是False

第二部分,我正在使用

trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, 
                                          shuffle=True, num_workers=0)

我在训练模型之前将trainloader.dataset.targets保存到变量e,将trainloader.dataset.data保存到变量f。然后,我使用trainloader训练模型。训练完成后,我将trainloader.dataset.targets保存到变量g,将trainloader.dataset.data保存到变量h。我希望自e == gf == hFalse都为shuffle=True,但是它们再次给出TrueDataLoader类的定义中我缺少什么?

python neural-network pytorch shuffle training-data
1个回答
1
投票

我相信直接存储在trainloader.dataset.data或.target中的数据不会被改组,仅当将DataLoader称为生成器或迭代器时才对数据进行改组。

您可以通过多次执行next(iter(trainloader))来进行检查,而不会混洗和混洗,它们应该给出不同的结果

import torch
import torchvision

transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        ])
MNIST_dataset = torchvision.datasets.MNIST('~/Desktop/intern/',download = True, train = False,
                                           transform = transform)
dataLoader = torch.utils.data.DataLoader(MNIST_dataset,
                                         batch_size = 128,
                                         shuffle = False,
                                         num_workers = 10)
target = dataLoader.dataset.targets


MNIST_dataset = torchvision.datasets.MNIST('~/Desktop/intern/',download = True, train = False,
                                           transform = transform)

dataLoader_shuffled= torch.utils.data.DataLoader(MNIST_dataset,
                                         batch_size = 128,
                                         shuffle = True,
                                         num_workers = 10)

target_shuffled = dataLoader_shuffled.dataset.targets

print(target == target_shuffled)

_, target = next(iter(dataLoader));
_, target_shuffled = next(iter(dataLoader_shuffled))

print(target == target_shuffled)

这将给:

tensor([True, True, True,  ..., True, True, True])
tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,  True,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False,  True, False, False, False, False, False,
        False,  True, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False,  True,  True, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False,  True, False, False,  True, False,
        False, False, False, False, False, False, False, False])

但是,存储在数据和目标中的数据和标签是固定列表,并且由于您尝试直接访问它,因此它们不会被随机播放。

© www.soinside.com 2019 - 2024. All rights reserved.