Tensorflow 数据集导致内存泄漏

问题描述 投票:0回答:1

我正在为我的 xception 模型训练创建一个数据集,我这样做:

def load_dataset(height: int, width: int, path: str, kind: str, batch_size=32) -> tf.data.Dataset:
    directory = os.path.join(path, kind)

    return keras.utils.image_dataset_from_directory(
        directory=directory,
        labels='inferred',
        label_mode='categorical',
        batch_size=batch_size,
        image_size=(height, width))


def load_rebalanced_dataset(
        height: int,
        width: int,
        path: str,
        kind: str,
        batch_size=32,
        do_repeat: bool = True) -> (tf.data.Dataset, list[str]):
    dataset = load_dataset(height, width, path, kind, batch_size)
    classes = dataset.class_names
    num_classes = len(classes)
    class_datasets = []

    for class_idx in range(num_classes):
        class_datasets.append(dataset.filter(lambda x, y: tf.reduce_any(tf.equal(tf.argmax(y, axis=-1), class_idx))))

    balanced_ds = tf.data.Dataset.sample_from_datasets(class_datasets, [1.0 / num_classes] * num_classes)

    if do_repeat:
        balanced_ds = balanced_ds.repeat()

    balanced_ds = balanced_ds.cache().prefetch(tf.data.AUTOTUNE)

    return balanced_ds, classes

load_rebalanced_dataset
函数中的某些内容使 tf 填满所有可用 RAM 内存(64 GB)。它在前 17 个 epoch 中训练没有问题,直到出现 OOM 错误。我能做点什么吗?

python tensorflow tensorflow-datasets
1个回答
0
投票

我也遇到这个问题,最后找到了一些在可能情况下工作的人,这对你也有帮助。

尝试删除“.cache() 你:

balanced_ds = balanced_ds.cache().prefetch(tf.data.AUTOTUNE)

解决方案:

balanced_ds = balanced_ds.prefetch(tf.data.AUTOTUNE)

据我了解,cache()会在第一个时期在内存中构建所有数据,以便在下一个时期更快。 为什么它在第一个纪元上工作,然后崩溃,神秘......可能处理会增加一些内存积累,并且数据几乎使用你的容量......我不确定,但尝试一下,它对我有用。

© www.soinside.com 2019 - 2024. All rights reserved.