在神经网络实现中耗尽内存(使用Numpy数组)

问题描述 投票:0回答:1

我的数据集采用以下形式:

培训数据

一个numpy数组(7855,448,448,3),其中(448,448,3)是RGB图像的numpy版本。由于网络的目的是回归,我还没有找到使用ImageDataGenerator的解决方案。所以,我已经将整个图像数据集转换为numpy数组。

培训目标

训练目标是大小为7855的1维numpy阵列。条​​目对应于训练数据的条目。

为了掌握numpy数组,我必须将整个数据集加载到内存中的变量中,然后传递它以适应和预测。这仅需要5到6演出的RAM。

在拟合模型时,RAM会快速溢出,运行时崩溃。如何批量提供numpy数组元素,或者是否有另一种加载数据集的方式:

|list of images |
|labelled       |
|1, 2, 3...     |
|n              |


|csv file with: |
|1   target1    |
|2   target2    |
|3   target3... |

代码https://colab.research.google.com/drive/1FUvPcpYiDtli6vwIaTwacL48RwZ0sq-9

[我一直在使用谷歌Colab,因为这是一个学术研究项目,还没有投资高端服务器。 ]

python numpy tensorflow keras deep-learning
1个回答
0
投票

您需要使用数据集API。当您创建numpy数组,train_images,train_target时,请使用tf.data.Dataset.from_tensor_slices

dataset = tf.data.Dataset.from_tensor_slices((train_images, train_target))

这将创建数据集对象,可以将其输入model.fit您可以将任何解析函数随机,批处理和映射到此数据集。您可以控制将使用shuffle缓冲区预加载的示例数。重复控制epoch计数,最好留下None,所以它将无限重复。

dataset = dataset.shuffle().repeat()
dataset = dataset.batch()

请记住,批处理在此管道内部进行,因此您不需要在model.fit中使用批处理,但是您需要在每个时期传递多个历元和步骤。后者可能有点棘手,因为你不能像len(dataset)那样做,所以应该提前计算。

model.fit(dataset, epochs, steps_per_epoch)

如果您遇到graphdef limit error,最好保存几个较小的numpy数组并将它们作为列表传递

让自己熟悉本指南qazxsw poi希望这会有所帮助。

© www.soinside.com 2019 - 2024. All rights reserved.