如何在 TensorFlow 中实现大数据集的交叉验证而不将整个数据集加载到内存中?

问题描述 投票:0回答:1

我目前正在处理一个机器学习项目的大型数据集,并选择使用 TensorFlow 的 tf.data API 来高效管理数据加载和预处理,而无需将整个数据集加载到内存中。这种方法对我的初始训练效果很好。

但是我很难实施交叉验证。据我了解,TensorFlow 本身并不支持直接通过 tf.data API 进行交叉验证,而与 Keras 集成进行交叉验证似乎需要先将数据加载到内存中。这对我的使用来说是有问题的,因为立即将整个数据集加载到内存中违背了使用 tf.data 的目的。

我正在寻找一种解决方法或方法来实现与 TensorFlow 的按需数据加载兼容的交叉验证。理想情况下,我希望保持 tf.data 的内存效率,同时对模型的评估进行交叉验证。

有没有一种方法可以使用 Keras 或任何其他库进行交叉验证,而不需要我将所有数据集加载到内存中?

tensorflow machine-learning keras cross-validation tf.data.dataset
1个回答
0
投票

TensorFlow的tf.data不直接支持交叉验证,有多种方法可以在保持内存效率的同时实现它。以下是您的选择:

  1. 外部库:IterativeStratification:该库与您的 tf.data 管道无缝集成,允许您定义 k 倍交叉验证,而无需立即加载整个数据集。 (https://github.com/trent-b/iterative-stratification/blob/master/iterstrat/ml_stratifiers.py) Imbalanced-learn:该库提供了专门为大型数据集设计的各种交叉验证方法,其中一些方法具有内存效率并且与 tf.data 兼容。 (https://imbalanced-learn.org/

  2. 这不是真正的交叉验证,但它可以帮助减轻大型数据集中的过度拟合,同时节省内存。通过随机洗牌来训练您的模型,并根据验证性能采用提前停止。

  3. 基于云的解决方案(如果适用):Google Cloud AI Platform 和 Amazon SageMaker 等平台提供内置的 k 倍交叉验证功能,不会将整个数据集加载到内存中。这可能会产生额外的成本或平台限制。

  4. 带有手动分割的分层 KFold:这需要更多的手动操作,但仍然节省内存。使用 scikit-learn 等外部工具将数据预先拆分为多个折叠,然后使用过滤器为每个折叠创建单独的 tf.data 管道,以在训练期间访问相关数据。

© www.soinside.com 2019 - 2024. All rights reserved.