Tensorflow 2.0:在多输入情况下构造tf.data.Dataset输出的最佳方法

问题描述 投票:0回答:1

我在Tensorflow上构建GAN进行图像去模糊,它是DeblurGANv2的实现。我将GAN设置为具有两个输入,一批模糊的图像和一批清晰的图像。遵循这一行,我将输入设计为带有两个键['sharp', 'blur']的Python字典,每个键都有一个形状为[batch_size, 512, 512, 3]的张量,这使将模糊图像批处理轻松馈送到生成器,然后馈给输出变得容易生成器的图像和清晰的图像批给鉴别器。

基于最后的要求,我创建了一个tf.data.Dataset,该输出精确地输出该命令,该命令包含两个张量,每个张量及其批处理尺寸。这与我的GAN实施相得益彰,一切正常且运行顺利。

因此请记住,我的输入不是张量,而是没有批处理维的python dict,这与稍后解释我的问题有关。

最近,我决定使用Tensorflow分配策略来增加对分布式培训的支持。 Tensorflow的此功能允许将培训分布在多个设备上,包括在多台机器上。某些实现中有一个功能,例如MirroredStrategy,它将输入张量,将其张成相等的部分并将每个切片馈送到不同的设备,这意味着,如果批处理大小为16和4个GPU ,每个GPU将结束一个本地批处理的4个数据点,在这之后,汇总结果和与我的问题无关的其他内容就有些神奇了。

您已经注意到,对于将张量作为输入,或者至少具有外部批处理尺寸的某种输入,而我拥有的是Python字典,在输入中具有批处理尺寸,对于分配策略至关重要内部字典张量值。这是一个很大的问题,我当前的实现与分布式培训不兼容。

我一直在寻找解决方法,但是我不能很好地解决这个问题,也许只是将输入的shape=[batch_size, 2, 512, 512, 3]张量切成张量?不确定现在才想到这大声笑。无论如何,我都觉得这很模棱两可,我无法区分这两个输入,至少在字典键不清晰的情况下。编辑:此解决方案的问题是使我的数据集转换非常昂贵,因此使数据集吞吐速度变慢,考虑到这是图像加载管道,这是重点。

也许我对分布式策略的工作方式的解释并不是最严格的,如果我看不到有什么办法可以纠正我的话。

PD:这不是错误问题或代码错误,主要不是“系统设计查询”,希望这里不是非法的]]

我在Tensorflow上构建GAN进行图像去模糊,它是DeblurGANv2的实现。我将GAN设置为具有两个输入,一批模糊的图像和一批清晰的图像。 ...

python tensorflow tensorflow2.0 tensorflow-datasets
1个回答
0
投票

而不是使用字典作为GAN的输入,您可以尝试通过以下方式映射功能,

def load_image(fileA,fileB):
    imageA = tf.io.read_file(fileA)
    imageA = tf.image.decode_jpeg(imageA, channels=3)

    imageB = tf.io.read_file(fileB)
    imageB = tf.image.decode_jpeg(imageB)
    return imageA,imageB

trainA = glob.glob('blur/*.jpg')
trainB = glob.glob('sharp/*.jpg')
train_dataset = tf.data.Dataset.from_tensor_slices((trainA,trainB))
train_dataset = train_dataset.map(load_image).batch(batch_size)

#for mirrored strategy

dist_dataset = mirrored_strategy.experimental_distribute_dataset(train_dataset)
© www.soinside.com 2019 - 2024. All rights reserved.