我使用两个FIFOQueue
来读取数据(输入文件和标签文件),它工作正常。但是当使用RandomShuffleQueue
时,似乎输入文件和标签文件无法对齐。
这是一个简单的例子:
使用FIFOQueue
,一切都很好
import tensorflow as tf
input_queue = tf.FIFOQueue(capacity=50, dtypes="int32", shapes=[()])
label_queue = tf.FIFOQueue(capacity=50, dtypes ="int32", shapes=[()])
input_op = input_queue.enqueue_many((range(5),))
label_op = label_queue.enqueue_many((range(5),))
input_res = input_queue.dequeue_many(10)
label_res = label_queue.dequeue_many(10)
with tf.Session() as sess:
#filled the queue
for _ in range(10):
sess.run([input_op,label_op])
print sess.run([input_res,label_res])
输入和标签数据中的顺序是匹配的。
[array([0, 1, 2, 3, 4, 0, 1, 2, 3, 4], dtype=int32),
array([0, 1, 2, 3, 4, 0, 1, 2, 3, 4], dtype=int32)]
但对于RandomShuffleQueue
input_queue = tf.RandomShuffleQueue(capacity=50, min_after_dequeue=2,
dtypes="int32", shapes=[()])
label_queue = tf.RandomShuffleQueue(capacity=50, min_after_dequeue=2,
dtypes ="int32", shapes=[()])
订单更改如下:
[array([1, 1, 4, 3, 0, 0, 1, 4, 2, 0], dtype=int32),
array([3, 2, 0, 0, 1, 1, 4, 0, 4, 3], dtype=int32)]
你可以看到,它没有对齐。如何使它工作?
您应该使用single Queue
同时读取输入和标签,而不是管理两个单独的队列,如下所示:
input_queue = tf.RandomShuffleQueue(capacity=50, min_after_dequeue=2,
dtypes=[tf.int32, tf.int32], shapes=[[],[]])
data = tf.range(5)
label = tf.range(5)
enqueue_op = input_queue.enqueue_many([data, label])
dequeue = input_queue.dequeue_many(10)
with tf.Session() as sess:
#filled the queue
for _ in range(10):
sess.run(enqueue_op)
print (sess.run(dequeue))
输出:
[array([1, 3, 0, 1, 1, 2, 1, 4, 1, 3], dtype=int32),
array([1, 3, 0, 1, 1, 2, 1, 4, 1, 3], dtype=int32)]
我在API中找到了一个有用的选项,即seed
input_queue = tf.RandomShuffleQueue(capacity=50, min_after_dequeue=2,
dtypes="int32", shapes=[()], seed=1234)
label_queue = tf.RandomShuffleQueue(capacity=50, min_after_dequeue=2,
dtypes="int32", shapes=[()], seed=1234)
[array([4,2,4,2,2,1,2,3,4,3],dtype = int32),array([4,2,4,2,2,1,2,3,4] ,3],dtype = int32)]