我有一个包含11个样本的数据集。当我选择BATCH_SIZE
为2时,以下代码将出错:
dataset = tf.contrib.data.TFRecordDataset(filenames)
dataset = dataset.map(parser)
if shuffle:
dataset = dataset.shuffle(buffer_size=128)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(count=1)
问题在于dataset = dataset.batch(batch_size)
,当Dataset
循环进入最后一批时,剩余的样本数量只有1,所以有没有办法从之前访问过的样本中随机挑选一个并生成最后一批?
@mining通过填充文件名来提出解决方案。
另一种解决方案是使用tf.contrib.data.batch_and_drop_remainder
。这将使用固定的批处理大小批处理数据并删除最后一个较小的批处理。
在您的示例中,使用11个输入和批量大小为2,这将产生5批2个元素。
以下是文档中的示例:
dataset = tf.data.Dataset.range(11)
batched = dataset.apply(tf.contrib.data.batch_and_drop_remainder(2))
你可以在你的drop_remainder=True
电话中设置batch
。
dataset = dataset.batch(batch_size, drop_remainder=True)
drop_remainder :(可选。)一个tf.bool标量tf.Tensor,表示在少于batch_size元素的情况下是否应删除最后一批;默认行为是不删除较小的批处理。