PytorchLightning:模型调用顺序

问题描述 投票:0回答:3

我正在尝试在 pytorch Lightning 之上重新实现训练管道。

在文档中,他们解释了训练/验证循环是这样执行的:

我的理解是顺序是:

  • train_step()
  • train_epoch_end()
  • val_step()
  • val_epoch_end()

我已经实现了一个虚拟代码来检查这一点:


import pytorch_lightning as pl
from torchmetrics import MeanMetric, SumMetric
from torch.utils.data import Dataset,DataLoader
import torch
import warnings
warnings.filterwarnings("ignore")

class DummyDataset(Dataset):
    def __init__(self):
        pass
    def __getitem__(self,idx):
        return torch.zeros([3,12,12]),torch.ones([3,12,12]) # Dummy image Like...
    def __len__(self):
        return 50

class DummyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3,3,1,1) # Useless convolution
        self.mean = MeanMetric()
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(),lr=1e-3)
    def training_step(self, batch,batch_idx):
        x,y=batch
        y_hat = self(x)
        loss = torch.sum((y-y_hat)**2)
        self.mean.update(2)
        return loss

    def training_epoch_end(self, outputs):
        mean_train = self.mean.compute()
        print(f"\nmean_train is : {mean_train}\n")
        self.mean.reset()

    def validation_step(self, batch,batch_idx):
        x,y=batch
        y_hat = self(x)
        loss = torch.sum((y-y_hat)**2)
        self.mean.update(4)
        return loss

    def validation_epoch_end(self, outputs):
        mean_val = self.mean.compute()
        print(f"\nmean_val is : {mean_val}\n")
        self.mean.reset()

    def forward(self,x):
        return self.conv(x)

if __name__=='__main__':
    dataset = DummyDataset()
    train_loader=DataLoader(dataset,batch_size=4,num_workers=0)
    val_loader=DataLoader(dataset,batch_size=4,num_workers=0)
    model = DummyModel()
    # We create trainer
    trainer = pl.Trainer(val_check_interval=None)
    # We fit model
    trainer.fit(model,train_dataloaders=train_loader,val_dataloaders=val_loader)

我在输出中看到的是:

  • mean_val 为:3
  • mean_train 是:nan

它与我在调试器中看到的一致,并且顺序是:

  • train_step()
  • val_step() ...
  • val_epoch_end()
  • train_epoch_end()

是这样吗?
我做错了什么吗?
它是如何工作的?
谢谢!

python torch pytorch-lightning
3个回答
0
投票

您观察到的顺序是正确的。这是如何实现的草图:

for epoch in range(max_epocks):
    for i, batch in enumerate(train_dataloader):
    
        model.training_step(batch, i)
        
        if should_validate():
            for i, batch in enumerate(val_dataloader):
                 model.validation_step(i, batch)
            model.validation_epoch_end()
    
     model.training_epoch_end()

如您所见,验证循环位于训练循环内部,并且可能会在批次级别上触发。这可以通过

Trainer(val_check_interval=x)
在 Trainer 中进行配置,其中 x 表示每 x 批次。

但默认情况下,它将验证每个纪元,这意味着每个

len(train_dataloader)
,因此
should_validate
条件在该纪元的最后一批上为真。这就是为什么您会在打印中看到:

val_epoch_end()
train_epoch_end()

(它们基本上同时发生)。

我希望这个解释有帮助。


0
投票

闪电中训练师的调用顺序。它包括一些回调函数。


0
投票

https://github.com/Lightning-AI/lightning/discussions/14318

这是作者的回答

当您运行 trainer.fit(...) 时,默认情况下会调用两次validation_step(),然后调用validation_epoch_end(),然后再调用training_step()。

所有钩子的顺序记录在:https://pytorch-lightning.readthedocs.io/en/1.7.2/common/lightning_module.html#hooks

def fit_loop():
    on_train_epoch_start()

    for batch in train_dataloader():
        on_train_batch_start()

        on_before_batch_transfer()
        transfer_batch_to_device()
        on_after_batch_transfer()

        training_step()

        on_before_zero_grad()
        optimizer_zero_grad()

        on_before_backward()
        backward()
        on_after_backward()

        on_before_optimizer_step()
        configure_gradient_clipping()
        optimizer_step()

        on_train_batch_end()

        if should_check_val:
            val_loop()
    # end training epoch
    training_epoch_end()

    on_train_epoch_end()
© www.soinside.com 2019 - 2024. All rights reserved.