What is the standard way to add a tf.placeholder for test/validation data when using queue runners?


I have a training model that takes all of the training data and builds a queue from it:

x = tf.placeholder(tf.float32, (N, steps, size), name='x')
y = tf.placeholder(tf.float32, (N, out_size), name='y')
var_x = tf.Variable(x, trainable=False, collections=[])
var_y = tf.Variable(y, trainable=False, collections=[])
x_queue, y_queue = tf.train.slice_input_producer([var_x, var_y], 
                                                 num_epochs=10, shuffle=True)
x_batch, y_batch = tf.train.batch([x_queue, y_queue], batch_size=batch_size)

...

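with tf.Session() as sess:
   # Run each variable's initializer (not the variable itself) so that the
   # fed arrays are copied into var_x and var_y.
   sess.run(var_x.initializer, feed_dict={x: X})
   sess.run(var_y.initializer, feed_dict={y: Y})
   # num_epochs creates a local epoch counter, which must be initialized too.
   sess.run(tf.local_variables_initializer())
   coord = tf.train.Coordinator()
   threads = tf.train.start_queue_runners(sess=sess, coord=coord)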

...

This network works fine and I am able to train it. I would now like to add a new placeholder to this network to take my test data:

x_test = tf.placeholder(tf.float32, (1, steps, size), name='x_test')

I want to use tf.cond to control which placeholder is used as the input:

rnn_inputs = tf.cond(is_train, lambda: x, lambda: x_test)

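However, many posts say that tf.cond is inefficient. Besides, adding a new placeholder for test/validation data is a problem in itself: even when I am only training the model, TensorFlow raises an error asking me to feed data into it.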

Is there a standard way to do this?

python tensorflow neural-network
1 Answer

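The most efficient approach is to feed the data through an iterator. You can create a handle that selects whether elements come from the training dataset or the validation dataset. Below is an example from https://www.tensorflow.org/programmers_guide/datasets. I have found this approach to work well.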

import tensorflow as tf  # TensorFlow 1.x graph-mode API

# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
    lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()
validation_dataset = tf.data.Dataset.range(50)

# A feedable iterator is defined by a handle placeholder and its structure. We
# could use the output_types and output_shapes properties of either
# training_dataset or validation_dataset here, because they have
# identical structure.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()

# You can use feedable iterators with a variety of different kinds of iterator
# (such as one-shot and initializable iterators).
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()

# The Iterator.string_handle() method returns a tensor that can be evaluated
# and used to feed the handle placeholder.
sess = tf.Session()  # the snippet assumes a session; one is created here
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())

# Loop forever, alternating between training and validation.
while True:
  # Run 200 steps using the training dataset. Note that the training dataset is
  # infinite, and we resume from where we left off in the previous `while` loop
  # iteration.
  for _ in range(200):
    sess.run(next_element, feed_dict={handle: training_handle})

  # Run one pass over the validation dataset.
  sess.run(validation_iterator.initializer)
  for _ in range(50):
    sess.run(next_element, feed_dict={handle: validation_handle})
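
Adapted to the arrays in the question, the same idea might look like the minimal sketch below. It is only an illustration under stated assumptions: it imports `tensorflow.compat.v1` (so it assumes TensorFlow 1.14 or later), the sizes and the random `X`, `Y`, `X_test` and `Y_test` arrays are made up to stand in for the asker's data, and the datasets replace the queue runners rather than wrap them. The batch dimension of the iterator is left unknown so that training batches of `batch_size` and test batches of 1 can share one `x_batch` tensor, which removes the need for tf.cond or a second placeholder.

```python
import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF >= 1.14, graph-mode API

tf.disable_eager_execution()

# Hypothetical sizes and random data standing in for the question's arrays.
N, steps, size, out_size, batch_size = 32, 5, 3, 2, 8
X = np.random.rand(N, steps, size).astype(np.float32)
Y = np.random.rand(N, out_size).astype(np.float32)
X_test = np.random.rand(4, steps, size).astype(np.float32)
Y_test = np.random.rand(4, out_size).astype(np.float32)

# Both datasets yield (x, y) pairs of the same dtypes; only batch size differs.
train_ds = (tf.data.Dataset.from_tensor_slices((X, Y))
            .shuffle(N).repeat().batch(batch_size))
test_ds = tf.data.Dataset.from_tensor_slices((X_test, Y_test)).batch(1)

# One feedable iterator; the string handle decides which dataset it reads.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle,
    tf.data.get_output_types(train_ds),
    # Batch dimension left as None so both batch sizes are accepted.
    (tf.TensorShape([None, steps, size]), tf.TensorShape([None, out_size])))
x_batch, y_batch = iterator.get_next()  # build the model on these tensors

train_iter = tf.data.make_one_shot_iterator(train_ds)
test_iter = tf.data.make_initializable_iterator(test_ds)

with tf.Session() as sess:
    train_handle, test_handle = sess.run(
        [train_iter.string_handle(), test_iter.string_handle()])

    # Training step: only the handle is fed, no data placeholder is needed.
    xb, yb = sess.run([x_batch, y_batch], feed_dict={handle: train_handle})
    train_shape = xb.shape

    # Test step: reset the test iterator, then switch handles.
    sess.run(test_iter.initializer)
    xt, yt = sess.run([x_batch, y_batch], feed_dict={handle: test_handle})
    test_shape = xt.shape
```

Note that `from_tensor_slices` on NumPy arrays embeds them in the graph as constants, which is fine for small data; for large arrays the same guide suggests feeding them through placeholders into an initializable iterator instead.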