K-fold cross-validation with PyTorch

Question (0 votes, 1 answer)

I am adding k-fold cross-validation to my training script. The documentation I have read says the training loop should sit inside the fold loop, and that the data loaders should be inside it too. In my case they are not. If I want to keep the data loaders defined outside the fold loop and call them from inside it, how can I do that? These are the functions:

def get_train_utils(opt, model_parameters):

    # data augmentation
    # ...
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=opt.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=opt.n_threads,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               worker_init_fn=worker_init_fn)
    return (train_loader, train_sampler, train_logger, train_batch_logger,
            optimizer, scheduler)

def get_val_utils(opt):
    # data augmentation
    # ...
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=(opt.batch_size //
                                                         opt.n_val_samples),
                                             shuffle=False,
                                             num_workers=opt.n_threads,
                                             pin_memory=True,
                                             sampler=val_sampler,
                                             worker_init_fn=worker_init_fn,
                                             collate_fn=collate_fn)
    return val_loader, val_logger

The training and validation loops are defined in another function:

def main_worker(index, opt):
    # other code
    ...

    if not opt.no_train:
        (train_loader, train_sampler, train_logger, train_batch_logger,
         optimizer, scheduler) = get_train_utils(opt, parameters)
        if opt.resume_path is not None:
            opt.begin_epoch, optimizer, scheduler = resume_train_utils(
                opt.resume_path, opt.begin_epoch, optimizer, scheduler)
            if opt.overwrite_milestones:
                scheduler.milestones = opt.multistep_milestones
    if not opt.no_val:
        val_loader, val_logger = get_val_utils(opt)

    if opt.tensorboard and opt.is_master_node:
        from torch.utils.tensorboard import SummaryWriter
        if opt.begin_epoch == 1:
            tb_writer = SummaryWriter(log_dir=opt.result_path)
        else:
            tb_writer = SummaryWriter(log_dir=opt.result_path,
                                      purge_step=opt.begin_epoch)
    else:
        tb_writer = None

    prev_val_loss = None

    for i in range(opt.begin_epoch, opt.n_epochs + 1):
        if not opt.no_train:
            if opt.distributed:
                train_sampler.set_epoch(i)
            current_lr = get_lr(optimizer)
            train_epoch(i, train_loader, model, criterion, optimizer,
                        opt.device, current_lr, train_logger,
                        train_batch_logger, tb_writer, opt.distributed)

            if i % opt.checkpoint == 0 and opt.is_master_node:
                save_file_path = opt.result_path / 'save_{}.pth'.format(i)
                save_checkpoint(save_file_path, i, opt.arch, model, optimizer,
                                scheduler)

        if not opt.no_val:
            prev_val_loss = val_epoch(i, val_loader, model, criterion,
                                      opt.device, val_logger, tb_writer,
                                      opt.distributed)

        if not opt.no_train and opt.lr_scheduler == 'multistep':
            scheduler.step()
        elif not opt.no_train and opt.lr_scheduler == 'plateau':
            scheduler.step(prev_val_loss)
Tags: python, pytorch, conv-neural-network, cross-validation, pytorch-dataloader
1 Answer (0 votes)

A DataLoader is created from a dataset, and in k-fold cross-validation the datasets are created by the k-fold split.

Your requirement -

I want to use the dataloaders defined outside the fold loop

- does not make sense. A DataLoader wraps one fixed dataset split, while k-fold cross-validation requires a different split for each fold. If you want to do k-fold cross-validation, you have to create a new train/validation split (and therefore new loaders) on every fold.
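To make the "different split per fold" point concrete, here is a minimal sketch of how k-fold index splitting works, using only the standard library (no torch needed; the function name `kfold_indices` is illustrative, not from the asker's code):

```python
def kfold_indices(n_samples, n_folds):
    """Yield (train_idx, val_idx) pairs, one per fold.

    Each fold uses a different contiguous slice as validation and
    everything else as training, so every sample is validated exactly once.
    """
    # Distribute any remainder across the first folds.
    fold_sizes = [n_samples // n_folds + (1 if i < n_samples % n_folds else 0)
                  for i in range(n_folds)]
    indices = list(range(n_samples))
    start = 0
    for size in fold_sizes:
        val_idx = indices[start:start + size]
        train_idx = indices[:start] + indices[start + size:]
        yield train_idx, val_idx
        start += size

for fold, (train_idx, val_idx) in enumerate(kfold_indices(10, 5)):
    print(fold, val_idx)  # each fold validates on a different 2 samples
```

In practice you would shuffle the indices once before slicing; sklearn's `KFold` does the same job with `shuffle=True`.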

The pseudocode looks like this:

dataset = ...

for k in range(n_folds):
    train_dataset, valid_dataset = split_dataset(dataset, fold=k, n_folds=n_folds)

    train_dataloader = ... # create from train_dataset
    valid_dataloader = ... # create from valid_dataset

    train_epoch(train_dataloader, valid_dataloader, ...)
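Filling in that pseudocode, a runnable sketch (not the asker's exact setup: the toy `TensorDataset`, `n_folds`, and batch size are illustrative assumptions) rebuilds the loaders inside the fold loop from one shared dataset, using `SubsetRandomSampler` so only the per-fold samplers change:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, SubsetRandomSampler

# Toy dataset: 20 samples, 4 features, binary labels.
dataset = TensorDataset(torch.randn(20, 4), torch.randint(0, 2, (20,)))
n_folds = 4
fold_size = len(dataset) // n_folds

for fold in range(n_folds):
    # A different contiguous validation slice each fold.
    val_idx = list(range(fold * fold_size, (fold + 1) * fold_size))
    train_idx = [i for i in range(len(dataset)) if i not in set(val_idx)]

    # New loaders per fold: same dataset object, different samplers.
    train_loader = DataLoader(dataset, batch_size=5,
                              sampler=SubsetRandomSampler(train_idx))
    val_loader = DataLoader(dataset, batch_size=5,
                            sampler=SubsetRandomSampler(val_idx))

    n_train = sum(x.size(0) for x, _ in train_loader)
    n_val = sum(x.size(0) for x, _ in val_loader)
    print(f"fold {fold}: train={n_train} val={n_val}")
```

For the asker's code, this suggests passing the per-fold indices (or a sampler built from them) into `get_train_utils` / `get_val_utils` and calling those helpers inside the fold loop, rather than trying to reuse one loader across folds.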