I have a generator that creates synthetic data. How can I convert it into a PyTorch DataLoader?
Wrap your generator in a data.IterableDataset:
from torch.utils import data

class IterDataset(data.IterableDataset):
    def __init__(self, generator):
        self.generator = generator

    def __iter__(self):
        # Return the wrapped generator; the DataLoader pulls items from it.
        return self.generator
Then wrap this dataset in a data.DataLoader.
Here is a minimal example showing how it is used:
>>> gen = (x for x in range(10))
>>> dataset = IterDataset(gen)
>>> for i in data.DataLoader(dataset, batch_size=2):
...     print(i)
tensor([0, 1])
tensor([2, 3])
tensor([4, 5])
tensor([6, 7])
tensor([8, 9])
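One caveat with the class above: a generator object can be consumed only once, so a second pass over the DataLoader (a second epoch, say) yields nothing. Below is a minimal sketch of one workaround, assuming you can pass the function that builds the generator rather than the generator itself. `ReiterableDataset` and `make_gen` are hypothetical names used only for illustration.

```python
from torch.utils import data


class ReiterableDataset(data.IterableDataset):
    """Wraps a generator *function* so each pass gets a fresh generator."""

    def __init__(self, generator_fn):
        # generator_fn: zero-argument callable returning a new iterator.
        self.generator_fn = generator_fn

    def __iter__(self):
        # Called once per pass by the DataLoader, so every epoch restarts.
        return iter(self.generator_fn())


def make_gen():
    # Hypothetical stand-in for the synthetic-data generator.
    return (x for x in range(10))


loader = data.DataLoader(ReiterableDataset(make_gen), batch_size=2)
first_epoch = [batch.tolist() for batch in loader]
second_epoch = [batch.tolist() for batch in loader]
```

With this variant both passes produce the same five batches, whereas a dataset holding an already-created generator would come back empty the second time.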
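Based on the limited information you provided, this is the simplest solution (I assume your generator creates images from noise, as in a vanilla GAN):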
import torch

def get_data(batch_size, generator, latent_dim=512):
    # Sample latent noise and map it to a batch of images.
    z = torch.randn(batch_size, latent_dim)
    return generator(z)

def dataloader(batch_size, generator, iteration, latent_dim=512):
    # Yield `iteration` freshly generated batches.
    for i in range(iteration):
        yield get_data(batch_size, generator, latent_dim)
batch_size = 64
generator = GANs(...)
iteration = 100
latent_dim = 512
loader = dataloader(batch_size, generator, iteration, latent_dim)
for images in loader:
    # do something
    pass
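The two approaches can also be combined: putting the noise-to-image loop inside an IterableDataset lets a real DataLoader drive it. The following is only a sketch under stated assumptions. `nn.Linear` is a stand-in for the actual GAN generator, `GeneratedBatches` is a hypothetical name, and `batch_size=None` is passed to turn off the DataLoader's automatic batching because each yielded item is already a full batch.

```python
import torch
from torch import nn
from torch.utils import data


class GeneratedBatches(data.IterableDataset):
    """Yields `iterations` batches sampled from a generator network."""

    def __init__(self, generator, batch_size, iterations, latent_dim=512):
        self.generator = generator
        self.batch_size = batch_size
        self.iterations = iterations
        self.latent_dim = latent_dim

    def __iter__(self):
        for _ in range(self.iterations):
            z = torch.randn(self.batch_size, self.latent_dim)
            # Compute inside no_grad, but yield outside it so the
            # disabled-gradient mode does not leak into the caller's loop.
            with torch.no_grad():
                images = self.generator(z)
            yield images


# Stand-in for a trained GAN generator (assumption): 512-d noise -> 784 values.
generator = nn.Linear(512, 784)
dataset = GeneratedBatches(generator, batch_size=64, iterations=3)
# batch_size=None: each item from the dataset is already a batch.
loader = data.DataLoader(dataset, batch_size=None)
shapes = [tuple(images.shape) for images in loader]
```

Each item from `loader` here is a tensor of shape (64, 784). Because `__iter__` builds a new Python generator on every call, this dataset can also be iterated for more than one epoch.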