我正在为我的 xception 模型训练创建一个数据集,我这样做:
def load_dataset(height: int, width: int, path: str, kind: str, batch_size=32) -> tf.data.Dataset:
directory = os.path.join(path, kind)
return keras.utils.image_dataset_from_directory(
directory=directory,
labels='inferred',
label_mode='categorical',
batch_size=batch_size,
image_size=(height, width))
def load_rebalanced_dataset(
height: int,
width: int,
path: str,
kind: str,
batch_size=32,
do_repeat: bool = True) -> (tf.data.Dataset, list[str]):
dataset = load_dataset(height, width, path, kind, batch_size)
classes = dataset.class_names
num_classes = len(classes)
class_datasets = []
for class_idx in range(num_classes):
class_datasets.append(dataset.filter(lambda x, y: tf.reduce_any(tf.equal(tf.argmax(y, axis=-1), class_idx))))
balanced_ds = tf.data.Dataset.sample_from_datasets(class_datasets, [1.0 / num_classes] * num_classes)
if do_repeat:
balanced_ds = balanced_ds.repeat()
balanced_ds = balanced_ds.cache().prefetch(tf.data.AUTOTUNE)
return balanced_ds, classes
load_rebalanced_dataset
函数中的某些内容使 tf 填满所有可用 RAM 内存(64 GB)。它在前 17 个 epoch 中训练没有问题,直到出现 OOM 错误。我能做点什么吗?
我也遇到这个问题,最后找到了一些在可能情况下工作的人,这对你也有帮助。
尝试删除“.cache() 你:
balanced_ds = balanced_ds.cache().prefetch(tf.data.AUTOTUNE)
解决方案:
balanced_ds = balanced_ds.prefetch(tf.data.AUTOTUNE)
据我了解,cache()会在第一个时期在内存中构建所有数据,以便在下一个时期更快。 为什么它在第一个纪元上工作,然后崩溃,神秘......可能处理会增加一些内存积累,并且数据几乎使用你的容量......我不确定,但尝试一下,它对我有用。