我的数据集采用以下形式:
培训数据
一个numpy数组(7855,448,448,3),其中(448,448,3)是RGB图像的numpy版本。由于网络的目的是回归,我还没有找到使用ImageDataGenerator的解决方案。所以,我已经将整个图像数据集转换为numpy数组。
培训目标
训练目标是大小为7855的1维numpy阵列。条目对应于训练数据的条目。
为了掌握numpy数组,我必须将整个数据集加载到内存中的变量中,然后传递它以适应和预测。这仅需要5到6演出的RAM。
在拟合模型时,RAM会快速溢出,运行时崩溃。如何批量提供numpy数组元素,或者是否有另一种加载数据集的方式:
|list of images |
|labelled |
|1, 2, 3... |
|n |
|csv file with: |
|1 target1 |
|2 target2 |
|3 target3... |
代码https://colab.research.google.com/drive/1FUvPcpYiDtli6vwIaTwacL48RwZ0sq-9
[我一直在使用谷歌Colab,因为这是一个学术研究项目,还没有投资高端服务器。 ]
您需要使用数据集API。当您创建numpy数组,train_images,train_target时,请使用tf.data.Dataset.from_tensor_slices
dataset = tf.data.Dataset.from_tensor_slices((train_images, train_target))
这将创建数据集对象,可以将其输入model.fit
您可以将任何解析函数随机,批处理和映射到此数据集。您可以控制将使用shuffle缓冲区预加载的示例数。重复控制epoch计数,最好留下None
,所以它将无限重复。
dataset = dataset.shuffle().repeat()
dataset = dataset.batch()
请记住,批处理在此管道内部进行,因此您不需要在model.fit
中使用批处理,但是您需要在每个时期传递多个历元和步骤。后者可能有点棘手,因为你不能像len(dataset)
那样做,所以应该提前计算。
model.fit(dataset, epochs, steps_per_epoch)
如果您遇到graphdef limit error,最好保存几个较小的numpy数组并将它们作为列表传递
让自己熟悉本指南qazxsw poi希望这会有所帮助。