[我正在使用Dataset API构建数据管道,但是当我训练多个GPU并在输入函数中返回dataset.make_one_shot_iterator().get_next()
时,我得到了
ValueError: dataset_fn() must return a tf.data.Dataset when using a tf.distribute.Strategy
我可以遵循错误消息并直接返回数据集,但我不了解iterator().get_next()
的目的以及它如何在单GPU或多GPU上进行训练。
...
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size = batch_size)
dataset = dataset.cache()
dataset = dataset.prefetch(buffer_size=None)
return dataset.make_one_shot_iterator().get_next()
return _input_fn
[将tf.data
与分配策略一起使用(可以与Keras和tf.Estimator
一起使用时,您的输入fn应该返回tf.data.Dataset
:]