Keras的ImageDataGenerator看起来非常适合简单地逐步加载图像并将迭代器传递给model.fit函数。但是,它似乎仅可用于图像和分类任务。
我想进行回归,即我的标签也是与训练集相同形状的数组。实际上,它们是类似于图像的多维(> 1个通道)数组,但它们不是图像。
关于使用什么类将批数据简单地吐到keras model.fit()上以训练深度神经网络的任何建议?
当然,问题是我的数据集太大而无法容纳在内存中,这就是为什么我需要使用这些生成器/迭代器的原因。
一种快速的解决方案是使用Colab,其GPU实例具有24 GB的RAM可以使用。您也可以像我here
那样加载numpy数组时减少内存。适合您的情况的最佳解决方案是使用tf.data.Dataset()
。
虽然可能需要相对较短的时间来习惯它,但建议的方式是加载数据并使用model.fit()。
您可以在这里查阅文档:https://www.tensorflow.org/api_docs/python/tf/data/Dataset
它是新的,快速的,设计精美且易于扩展的。
例如,对于您的问题,您可能需要使用tf.data.Dataset.from_tensor_slices()
;我会让您发现它的功能:D。