何时在 pytorch Lightning 中使用prepare_data 与 setup?

问题描述 投票:0回答:2

Pytorch 的文档关于数据加载器仅在代码中说

def prepare_data(self):
    # download
    ...

def setup(self, stage: Optional[str] = None):
    # Assign train/val datasets for use in dataloaders

请解释

prepare_data
setup
之间的预期分离、它们之间可能发生哪些回调,以及为什么将某些内容放在其中一个而不是另一个中。

python deep-learning pytorch pytorch-lightning
2个回答
5
投票

如果您查看

Trainer.fit
§ Hooks
的文档页面中提供的 LightningModule 函数的伪函数,您可以阅读:

def fit(self):
    if global_rank == 0:
        # prepare data is called on GLOBAL_ZERO only
        prepare_data()                                 ## <-- prepare_data

    configure_callbacks()

    with parallel(devices):
        # devices can be GPUs, TPUs, ...
        train_on_device(model)


def train_on_device(model):
    # called PER DEVICE
    on_fit_start()
    setup("fit")                                       ## <-- setup
    configure_optimizers()

    # the sanity check runs here

    on_train_start()
    for epoch in epochs:
        fit_loop()
    on_train_end()

    on_fit_end()
    teardown("fit")

您可以看到

prepare_data
仅被
global_rank == 0
调用,i.e. 它仅由单个处理器调用。原来你可以从
prepare_data
的文档描述中读到:

LightningModule.prepare_data()

用它来下载和准备数据。使用多个进程(分布式设置)下载和保存数据将导致数据损坏。 Lightning 确保仅在单个进程中调用此方法,因此您可以在中安全地添加下载逻辑。

setup
会在所有进程上调用,您可以从上面的伪代码及其文档描述中读取:

LightningModule.setup(stage=None)

在拟合开始时调用(训练+验证)、验证、测试或预测。当您需要动态构建模型或调整模型的某些内容时,这是一个很好的钩子。当使用 DDP 时,每个进程都会调用这个 hook。


0
投票

任何数据操作,例如数据下载,都只能在主进程中进行,以防止在多处理中调用prepare_data()时可能出现错误。
在多处理中有效工作的任何其他数据操作(例如从磁盘读取数据和以下数据集分割)都应在 setup() 中完成。

© www.soinside.com 2019 - 2024. All rights reserved.