如何对齐两个tf.RandomShuffleQueue?

问题描述 投票:0回答:2

我使用两个FIFOQueue来读取数据(输入文件和标签文件),它工作正常。但是当使用RandomShuffleQueue时,似乎输入文件和标签文件无法对齐。

这是一个简单的例子:

使用FIFOQueue,一切都很好

import tensorflow as tf

input_queue = tf.FIFOQueue(capacity=50, dtypes="int32", shapes=[()])
label_queue = tf.FIFOQueue(capacity=50, dtypes ="int32", shapes=[()])

input_op = input_queue.enqueue_many((range(5),))
label_op = label_queue.enqueue_many((range(5),))

input_res = input_queue.dequeue_many(10)
label_res = label_queue.dequeue_many(10)

with tf.Session() as sess:
    #filled the queue
    for _ in range(10):
        sess.run([input_op,label_op])
    print sess.run([input_res,label_res])

输入和标签数据中的顺序是匹配的。

[array([0, 1, 2, 3, 4, 0, 1, 2, 3, 4], dtype=int32), 
 array([0, 1, 2, 3, 4, 0, 1, 2, 3, 4], dtype=int32)]

但对于RandomShuffleQueue

input_queue = tf.RandomShuffleQueue(capacity=50, min_after_dequeue=2, 
                                    dtypes="int32", shapes=[()])
label_queue = tf.RandomShuffleQueue(capacity=50, min_after_dequeue=2, 
                                    dtypes ="int32", shapes=[()])

订单更改如下:

[array([1, 1, 4, 3, 0, 0, 1, 4, 2, 0], dtype=int32), 
 array([3, 2, 0, 0, 1, 1, 4, 0, 4, 3], dtype=int32)]

你可以看到,它没有对齐。如何使它工作?

python tensorflow
2个回答
2
投票

您应该使用single Queue同时读取输入和标签,而不是管理两个单独的队列,如下所示:

input_queue = tf.RandomShuffleQueue(capacity=50, min_after_dequeue=2, 
                                dtypes=[tf.int32, tf.int32], shapes=[[],[]])
data = tf.range(5)
label = tf.range(5)
enqueue_op = input_queue.enqueue_many([data, label])
dequeue = input_queue.dequeue_many(10)
with tf.Session() as sess:
   #filled the queue
   for _ in range(10):
      sess.run(enqueue_op)
   print (sess.run(dequeue))

输出:

[array([1, 3, 0, 1, 1, 2, 1, 4, 1, 3], dtype=int32), 
array([1, 3, 0, 1, 1, 2, 1, 4, 1, 3], dtype=int32)]

0
投票

我在API中找到了一个有用的选项,即seed

input_queue = tf.RandomShuffleQueue(capacity=50, min_after_dequeue=2, 
                                    dtypes="int32", shapes=[()], seed=1234)
label_queue = tf.RandomShuffleQueue(capacity=50, min_after_dequeue=2, 
                                    dtypes="int32", shapes=[()], seed=1234)

[array([4,2,4,2,2,1,2,3,4,3],dtype = int32),array([4,2,4,2,2,1,2,3,4] ,3],dtype = int32)]

© www.soinside.com 2019 - 2024. All rights reserved.