Pytorch 的文档关于数据加载器仅在代码中说
def prepare_data(self):
# download
...
和
def setup(self, stage: Optional[str] = None):
# Assign train/val datasets for use in dataloaders
请解释
prepare_data
和 setup
之间的预期分离、它们之间可能发生哪些回调,以及为什么将某些内容放在其中一个而不是另一个中。
Trainer.fit
的 § Hooks的文档页面中提供的
LightningModule
函数的伪函数,您可以阅读:
def fit(self):
if global_rank == 0:
# prepare data is called on GLOBAL_ZERO only
prepare_data() ## <-- prepare_data
configure_callbacks()
with parallel(devices):
# devices can be GPUs, TPUs, ...
train_on_device(model)
def train_on_device(model):
# called PER DEVICE
on_fit_start()
setup("fit") ## <-- setup
configure_optimizers()
# the sanity check runs here
on_train_start()
for epoch in epochs:
fit_loop()
on_train_end()
on_fit_end()
teardown("fit")
您可以看到
prepare_data
仅被 global_rank == 0
调用,i.e. 它仅由单个处理器调用。原来你可以从prepare_data
的文档描述中读到:
LightningModule.prepare_data()
用它来下载和准备数据。使用多个进程(分布式设置)下载和保存数据将导致数据损坏。 Lightning 确保仅在单个进程中调用此方法,因此您可以在中安全地添加下载逻辑。
setup
会在所有进程上调用,您可以从上面的伪代码及其文档描述中读取:
LightningModule.setup(stage=None)
在拟合开始时调用(训练+验证)、验证、测试或预测。当您需要动态构建模型或调整模型的某些内容时,这是一个很好的钩子。当使用 DDP 时,每个进程都会调用这个 hook。
任何数据操作,例如数据下载,都只能在主进程中进行,以防止在多处理中调用prepare_data()时可能出现错误。
在多处理中有效工作的任何其他数据操作(例如从磁盘读取数据和以下数据集分割)都应在 setup() 中完成。