如何在tf.data.Dataset中填充固定的BATCH_SIZE?

问题描述 投票:4回答:2

我有一个包含11个样本的数据集。当我选择BATCH_SIZE为2时,以下代码将出错:

dataset = tf.contrib.data.TFRecordDataset(filenames) 
dataset = dataset.map(parser)
if shuffle:
    dataset = dataset.shuffle(buffer_size=128)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(count=1)

问题在于dataset = dataset.batch(batch_size),当Dataset循环进入最后一批时,剩余的样本数量只有1,所以有没有办法从之前访问过的样本中随机挑选一个并生成最后一批?

tensorflow tensorflow-datasets
2个回答
7
投票

@mining通过填充文件名来提出解决方案。

另一种解决方案是使用tf.contrib.data.batch_and_drop_remainder。这将使用固定的批处理大小批处理数据并删除最后一个较小的批处理。

在您的示例中,使用11个输入和批量大小为2,这将产生5批2个元素。

以下是文档中的示例:

dataset = tf.data.Dataset.range(11)
batched = dataset.apply(tf.contrib.data.batch_and_drop_remainder(2))

3
投票

你可以在你的drop_remainder=True电话中设置batch

dataset = dataset.batch(batch_size, drop_remainder=True)

来自documentation

drop_remainder :(可选。)一个tf.bool标量tf.Tensor,表示在少于batch_size元素的情况下是否应删除最后一批;默认行为是不删除较小的批处理。

© www.soinside.com 2019 - 2024. All rights reserved.