如何将生成器转换为 Pytorch 数据加载器?

问题描述 投票:0回答:2

我有一个可以创建合成数据的生成器。我如何将其转换为 PyTorch 数据加载器?

pytorch pytorch-lightning
2个回答
9
投票

您可以用

data.IterableDataset
包裹您的发电机:

class IterDataset(data.IterableDataset):
    def __init__(self, generator):
        self.generator = generator

    def __iter__(self):
        return self.generator

然后您可以用

data.DataLoader
包装此数据集。

这是一个展示其用途的最小示例:

>>> gen = (x for x in range(10))

>>> dataset = IterDataset(gen)
>>> for i in data.DataLoader(dataset, batch_size=2):
...    print(i)
tensor([0, 1])
tensor([2, 3])
tensor([4, 5])
tensor([6, 7])
tensor([8, 9])

3
投票

根据您提供的有限信息,这是最简单的解决方案(我假设您的生成器从噪声中创建图像,例如原始甘斯):

import torch

def get_data(batch_size, generator, latent_dim=512):
    z = torch.randn(batch_size, latent_dim)
    return genenerator(z)

def dataloader(batch_size, generator, iteration, latent_dim=512):
    for i in range(iteration):
        yield(get_data(batch_size, generator, latent_dim))

batch_size = 64
generator = GANs(...)
iteration = 100
latent_dim = 512

loader = dataloader(batch_size, generator, iteration, latent_dim)
for images in loader:
    # do something
© www.soinside.com 2019 - 2024. All rights reserved.