我在Tensorflow上构建GAN进行图像去模糊,它是DeblurGANv2的实现。我将GAN设置为具有两个输入,一批模糊的图像和一批清晰的图像。遵循这一行,我将输入设计为带有两个键['sharp', 'blur']
的Python字典,每个键都有一个形状为[batch_size, 512, 512, 3]
的张量,这使将模糊图像批处理轻松馈送到生成器,然后馈给输出变得容易生成器的图像和清晰的图像批给鉴别器。
基于最后的要求,我创建了一个tf.data.Dataset
,该输出精确地输出该命令,该命令包含两个张量,每个张量及其批处理尺寸。这与我的GAN实施相得益彰,一切正常且运行顺利。
因此请记住,我的输入不是张量,而是没有批处理维的python dict,这与稍后解释我的问题有关。
最近,我决定使用Tensorflow分配策略来增加对分布式培训的支持。 Tensorflow的此功能允许将培训分布在多个设备上,包括在多台机器上。某些实现中有一个功能,例如MirroredStrategy
,它将输入张量,将其张成相等的部分并将每个切片馈送到不同的设备,这意味着,如果批处理大小为16和4个GPU ,每个GPU将结束一个本地批处理的4个数据点,在这之后,汇总结果和与我的问题无关的其他内容就有些神奇了。
您已经注意到,对于将张量作为输入,或者至少具有外部批处理尺寸的某种输入,而我拥有的是Python字典,在输入中具有批处理尺寸,对于分配策略至关重要内部字典张量值。这是一个很大的问题,我当前的实现与分布式培训不兼容。
我一直在寻找解决方法,但是我不能很好地解决这个问题,也许只是将输入的shape=[batch_size, 2, 512, 512, 3]
张量切成张量?不确定现在才想到这大声笑。无论如何,我都觉得这很模棱两可,我无法区分这两个输入,至少在字典键不清晰的情况下。编辑:此解决方案的问题是使我的数据集转换非常昂贵,因此使数据集吞吐速度变慢,考虑到这是图像加载管道,这是重点。
也许我对分布式策略的工作方式的解释并不是最严格的,如果我看不到有什么办法可以纠正我的话。
PD:这不是错误问题或代码错误,主要不是“系统设计查询”,希望这里不是非法的]]
我在Tensorflow上构建GAN进行图像去模糊,它是DeblurGANv2的实现。我将GAN设置为具有两个输入,一批模糊的图像和一批清晰的图像。 ...
而不是使用字典作为GAN的输入,您可以尝试通过以下方式映射功能,
def load_image(fileA,fileB):
imageA = tf.io.read_file(fileA)
imageA = tf.image.decode_jpeg(imageA, channels=3)
imageB = tf.io.read_file(fileB)
imageB = tf.image.decode_jpeg(imageB)
return imageA,imageB
trainA = glob.glob('blur/*.jpg')
trainB = glob.glob('sharp/*.jpg')
train_dataset = tf.data.Dataset.from_tensor_slices((trainA,trainB))
train_dataset = train_dataset.map(load_image).batch(batch_size)
#for mirrored strategy
dist_dataset = mirrored_strategy.experimental_distribute_dataset(train_dataset)