如何使用tf.data的可初始化迭代器和可重新初始化的插入器,并将数据馈送到估计器api?

问题描述 投票:5回答:2

[所有官方的google官方教程都对所有estimator api实现使用一次快照迭代器,我找不到任何有关如何使用tf.data的可初始化迭代器和可重新初始化迭代器而不是一个快照迭代器的文档。

[有人可以告诉我如何使用tf.data的可初始化迭代器和可重新初始化插入器在train_data和test_data之间进行切换。我们需要运行一个会话来使用feed dict,并在可初始化的迭代器,其低级api及其令人困惑的如何使用它的方法中切换数据集,该方法是估算器api体系结构的一部分

PS:我确实发现Google提到了“注意:当前,单次迭代器是唯一可与Estimator一起使用的类型。”

但是社区内部有什么工作吗?还是出于某些原因我们应该只使用一个镜头迭代器

python tensorflow tensorflow-datasets tensorflow-estimator
2个回答
5
投票
这里有个简单的例子,您可以适应自己的需求:

class IteratorInitializerHook(tf.train.SessionRunHook): def __init__(self): super(IteratorInitializerHook, self).__init__() self.iterator_initializer_func = None # Will be set in the input_fn def after_create_session(self, session, coord): self.iterator_initializer_func(session) def get_inputs(X, y): iterator_initializer_hook = IteratorInitializerHook() def input_fn(): X_pl = tf.placeholder(X.dtype, X.shape) y_pl = tf.placeholder(y.dtype, y.shape) dataset = tf.data.Dataset.from_tensor_slices((X_pl, y_pl)) dataset = ... ... iterator = dataset.make_initializable_iterator() next_example, next_label = iterator.get_next() iterator_initializer_hook.iterator_initializer_func = lambda sess: sess.run(iterator.initializer, feed_dict={X_pl: X, y_pl: y}) return next_example, next_label return input_fn, iterator_initializer_hook ... train_input_fn, train_iterator_initializer_hook = get_inputs(X_train, y_train) test_input_fn, test_iterator_initializer_hook = get_inputs(X_test, y_test) ... estimator.train(input_fn=train_input_fn, hooks=[train_iterator_initializer_hook]) estimator.evaluate(input_fn=test_input_fn, hooks=[test_iterator_initializer_hook])

这是我在blogpostSebastian Pölsterl中找到的代码的修改版本。在“通过数据集API将数据馈送到估算器”部分下查看。


0
投票
© www.soinside.com 2019 - 2024. All rights reserved.