用tf.data API替换tf.placeholder和feed_dict

问题描述 投票:13回答:2

我有一个现有的TensorFlow模型,它使用tf.placeholder作为模型输入,使用tf.Session()的feed_dict参数运行以输入数据。以前,整个数据集都被读入内存并以这种方式传递。

我想使用更大的数据集,并利用tf.data API的性能改进。我已经从中定义了一个tf.data.TextLineDataset和一次性迭代器,但是我很难弄清楚如何将数据导入模型来训练它。

起初我尝试将feed_dict定义为从占位符到iterator.get_next()的字典,但这给了我一个错误,说明feed的值不能是tf.Tensor对象。更多的挖掘让我明白这是因为iterator.get_next()返回的对象已经是图形的一部分了,不像你将它提供给feed_dict - 并且我不应该尝试使用feed_dict表现原因。

所以现在我已经摆脱了输入tf.placeholder并将其替换为定义我的模型的类的构造函数的参数;在我的训练代码中构建模型时,我将iterator.get_next()的输出传递给该参数。这似乎有点笨拙,因为它打破了模型定义与数据集/培训程序之间的分离。我现在收到一个错误,说Tensor代表(我相信)我的模型的输入必须来自与iterator.get_next()的Tensor相同的图形。

我是否采用这种方法走上了正确的轨道,并且只是在设置图形和会话或其他类似方面做错了什么? (数据集和模型都在会话之外初始化,并且在我尝试创建之前发生错误。)

或者我完全偏离这个并且需要做一些不同的事情,比如使用Estimator API并在输入函数中定义所有东西?

以下是一些演示最小示例的代码:

import tensorflow as tf
import numpy as np

class Network:
    def __init__(self, x_in, input_size):
        self.input_size = input_size
        # self.x_in = tf.placeholder(dtype=tf.float32, shape=(None, self.input_size))  # Original
        self.x_in = x_in
        self.output_size = 3

        tf.reset_default_graph()  # This turned out to be the problem

        self.layer = tf.layers.dense(self.x_in, self.output_size, activation=tf.nn.relu)
        self.loss = tf.reduce_sum(tf.square(self.layer - tf.constant(0, dtype=tf.float32, shape=[self.output_size])))

data_array = np.random.standard_normal([4, 10]).astype(np.float32)
dataset = tf.data.Dataset.from_tensor_slices(data_array).batch(2)

model = Network(x_in=dataset.make_one_shot_iterator().get_next(), input_size=dataset.output_shapes[-1])
python tensorflow tensorflow-datasets
2个回答
7
投票

我也花了一些时间来解决问题。你走在正确的轨道上。整个数据集定义只是图表的一部分。我通常将它创建为与我的Model类不同的类,并将数据集传递给Model类。我指定要在命令行上加载的数据集类,然后动态加载该类,从而模块化地解耦数据集和图形。

请注意,您可以(并且应该)命名数据集中的所有张量,当您通过所需的各种转换传递数据时,它确实有助于使事情易于理解。

您可以编写简单的测试用例,从iterator.get_next()中提取样本并显示它们,您将拥有像sess.run(next_element_tensor),没有feed_dict之类的东西,正如您所记录的那样。

一旦你了解它,你可能会开始喜欢数据集输入管道。它迫使你很好地模块化你的代码,并强制它进入一个易于单元测试的结构。

请务必阅读开发者指南,其中有大量示例:

https://www.tensorflow.org/programmers_guide/datasets

我要注意的另一件事是使用此管道处理火车和测试数据集是多么容易。这很重要,因为您经常对训练数据集执行数据增强,而这些数据集并未在测试数据集上执行,from_string_handle允许您这样做,并在上面的指南中有清楚的描述。


2
投票

从我给出的原始代码中的模型构造函数中的行tf.reset_default_graph()引起了它。删除修复它。

© www.soinside.com 2019 - 2024. All rights reserved.