TF记录读取管道在PREFETCH采样后变慢

问题描述 投票:1回答:1

我将训练数据划分为多个tf记录文件,并使用以下代码读取它们:

SHUFFLE_BUFFER = 64
PREFETCH = 256
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.shuffle(SHUFFLE_BUFFER) 
dataset = dataset.map(_parse_image_function, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(BATCH_SIZE)
dataset = dataset.prefetch(PREFETCH)
dataset = dataset.repeat()

此数据集直接输入到model.fit(dataset)。

第一个PREFETCH样本被快速加载,并且GPU利用率始终高于80%。但是,此后,快速读取似乎停止了,GPU使用率下降了,培训时间大大减少了。有人知道出什么事了吗?

tensorflow tensorflow2.0 tensorflow-datasets tf.keras
1个回答
0
投票
这很难在不了解更多详细信息的情况下进行诊断(存储后端,记录大小,每个文件的记录数,文件数,_parse_image_function中的任何io操作?,...。]

我的第一个怀疑是在tf.data.TFRecordDataset(filenames)上-在下一个文件之后打开一个文件可能会引入延迟尖峰,这可能会暂时使数据集CPU管道枯竭。 (多个较小的文件可能也没有自动预读带来的好处)

我将尝试在tf.data.TFRecordDataset(filenames)之后添加一个额外的预取,以解除IO的耦合(并可能交错记录来自不同文件(num_parallel_reads参数))。

如果预取没有帮助,我将尝试对num_parallel_calls进行硬编码(主要是因为我尚未阅读自动调谐代码-如果您的管道需要更多的默认并行性,则可能使用私有线程池。)>

取决于您的存储后端-一旦训练变慢(以测试/优化数据集),重复训练就会重新开始,可能只是从各种缓存中提取数据,而一旦使用的数据集超过了缓存,则可能会变慢。

© www.soinside.com 2019 - 2024. All rights reserved.