Tensorflow batch_join的allow_smaller_final_batch不起作用?

问题描述 投票:2回答:1

我正在使用tenosrlfow队列处理我的数据,我需要获得大小小于批量大小的最终批次,但我只能获得5个批量大小,最终批次无法获得。我不明白这是什么问题。

data = np.arange(105)
data_placeholder = tf.placeholder(dtype=tf.int64, shape=[None,])

queue = tf.FIFOQueue(capacity=200,dtypes=tf.int64,shapes=())
enqueue_op = queue.enqueue_many([data_placeholder])

data_list = []
data_ = queue.dequeue()

data_list.append([data_])
batch_data = tf.train.batch_join(data_list,batch_size=20, capacity=100 ,allow_smaller_final_batch=True)

sess = tf.Session()

coord = tf.train.Coordinator()
tf.train.start_queue_runners(sess,coord)

step = 0
under = 0
uper = 0
enqueu_step = len(data)//20 + 1
while step < enqueu_step:
    uper = uper + 20
    sess.run(enqueue_op, feed_dict={data_placeholder:data[under:uper]})
    print("enque step=%d/%d %d-%d" %(step, enqueu_step,under, uper))
    step = step + 1
    under = uper
i = 0
while i < enqueu_step:
    _data = sess.run(batch_data)
    print("setp=%d/%d shape=%s" % (i, enqueu_step,_data.shape))

    i = i + 1
print("end")
python tensorflow
1个回答
1
投票

我没有检查过你的整个代码,但是如果我做得对,你想得到所有的样品,即使最后一批比其余的小,对吧?

好吧,使用这个玩具示例,8个样本并使用3个批次:

import tensorflow as tf
import numpy as np

num_samples = 8
batch_size = 3
capacity = num_samples % batch_size # set the capacity to the actual remaining samples
data = np.arange(1, num_samples+1)
data_input = tf.constant(data)

batch = tf.train.batch([data_input], enqueue_many=True, batch_size=batch_size, capacity=capacity, allow_smaller_final_batch=True)

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    for i in range(2):
        print(i, sess.run([batch]))
    coord.request_stop()
    coord.join(threads)
    # this should follow a closed queue
    print(i+1, sess.run([batch]))

结果:

0 [数组([1,2,3])]

1 [数组([4,5,6])]

2 [array([7,8])]

这里的重要参数是enqueue_many,以便将每个数字视为单独的数字,并将capacity设置为实际的剩余样本(例如,这里是2)。如果capacity设置为1,您将获得1sample and if it's 3 you will miss theallow_smaller_final_batch`标志效果,因为它将返回3个样本(从头开始的最后一个)。

希望这澄清你应该使用allow_smaller_final_batch参数的方式。

© www.soinside.com 2019 - 2024. All rights reserved.