TensorFlow:有效地从大文件中读取(和随机播放)样本

问题描述 投票:0回答:1

设置

我有几十个中型文件(〜1G),对于给定的类,每个文件每行包含一个样本。在每个文件中,样本按非随机顺序排列,即文件A的第i个样本与文件B的第i个样本在某种程度上相关,因为数据是沿着每个类别的某个轴的样本(细节不重要)。


问题

读取和处理内存中的所有样本不是​​一种选择,因为(1)可能多达数百个文件(2)预处理后每个样本的内存占用量显着增加(例如,由于大的一键编码矢量)。

我的目标是从磁盘上有效地读取样本(或批次)并将其输入到我的tf.keras模型中。此外,我想重新排列每个时期之后将样品(或批次)送入网络的顺序。


我如何以合理的效率进行存档,即在训练期间我的GPU不会闲置?

python tensorflow keras tensorflow-datasets
1个回答
1
投票

这里是一个建议,假设您正在读取TFRecord文件。具体参数取决于每个示例的大小和可用资源:

import tensorflow as tf
ds = (tf.data.Dataset.list_files('data_dir/*.tfrecord')
      .cache()
      .repeat()
      .shuffle(1_000)
      .interleave(tf.data.TFRecordDataset, block_length=100,
                  # Optional
                  num_parallel_calls=tf.data.experimental.AUTOTUNE)
      .shuffle(10_000)
      .map(record_parse_function)
      .batch(32)
      .prefetch(1))

无论如何,建议阅读guide about tf.datatf.data

© www.soinside.com 2019 - 2024. All rights reserved.