tf.data: 并行加载步骤

问题描述 投票:7回答:1

我有一个数据输入流水线,它有。

  • 输入的数据点的类型是不能投射到的。tf.Tensor (口述)
  • 不能理解张量流类型的预处理函数,需要与这些数据点一起工作;其中一些函数在飞行中进行数据增强。

我一直想把这个装进一个... tf.data pipeline,我被卡在了并行运行多个数据点的预处理上。到目前为止,我已经尝试了以下方法。

  • 使用 Dataset.from_generator(gen) 并在生成器中进行预处理;这样做是可行的,但它会按顺序处理每个数据点,无论这些数据点是如何排列的 prefetch 和假冒 map 调用我对它进行修补。是不是不能并行预取?
  • 把预处理封装在一个 tf.py_function 好让我 map 在我的Dataset上同步进行,但
    1. 这需要一些非常丑陋的(去)序列化,以将外来类型融入到字符串 tensors 中。
    2. 显然,执行 py_function 会被移交给(单进程)python解释器,所以我只能使用python GIL,这对我帮助不大。
  • 我看到你可以做一些技巧与 interleave 但还没有找到任何没有前两个想法的问题。

我是否遗漏了什么?我是否被迫修改我的预处理,使其能够在图中运行,或者有什么方法可以进行多处理?

我们之前的方法是使用keras.Sequence,效果很好,但是有太多的人推崇升级为 tf.data API。(地狱,甚至尝试用tf 2.2的keras.Sequence也会得到 WARNING:tensorflow:multiprocessing can interact badly with TensorFlow, causing nondeterministic deadlocks. For high performance data pipelines tf.data is recommended.)

注:我用的是tf 2.2rc3。

python tensorflow tensorflow2.0 tensorflow-datasets
1个回答
1
投票

你可以尝试添加 batch() 之前 map() 在你的输入流水线中。

它通常是为了减少小地图函数调用的开销,见这里。https:/www.tensorflow.orgguidedata_performance#vectorizing_mapping

然而,你也可以用它来获取一批输入到你的地图上的数据 py_function 并使用python multiprocessing 在那里,以加快事情的发展。

这样你就可以绕过GIL的限制,这就使得 num_parallel_callstf.data.map() 无用 py_function 地图功能。

© www.soinside.com 2019 - 2024. All rights reserved.