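I'm trying to add k-fold cross-validation to my training script. According to the documentation I've read, the training loop should sit inside the fold loop. What I don't understand is that the dataloaders are also supposed to be inside the fold loop, and in my case they aren't. If I want to use dataloaders defined outside the fold loop and call them from inside it, how can I do that? These are the functions:
def get_train_utils(opt, model_parameters):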
    # data augmentation
    # ...........
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=opt.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=opt.n_threads,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               worker_init_fn=worker_init_fn)
    return (train_loader, train_sampler, train_logger, train_batch_logger,
            optimizer, scheduler)
and
def get_val_utils(opt):
    # data augmentation
    # ........
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=(opt.batch_size //
                                                         opt.n_val_samples),
                                             shuffle=False,
                                             num_workers=opt.n_threads,
                                             pin_memory=True,
                                             sampler=val_sampler,
                                             worker_init_fn=worker_init_fn,
                                             collate_fn=collate_fn)
    return val_loader, val_logger
The training and validation loops are defined in another function:
def main_worker(index, opt):
    # other code
    if not opt.no_train:
        (train_loader, train_sampler, train_logger, train_batch_logger,
         optimizer, scheduler) = get_train_utils(opt, parameters)
        if opt.resume_path is not None:
            opt.begin_epoch, optimizer, scheduler = resume_train_utils(
                opt.resume_path, opt.begin_epoch, optimizer, scheduler)
            if opt.overwrite_milestones:
                scheduler.milestones = opt.multistep_milestones
    if not opt.no_val:
        val_loader, val_logger = get_val_utils(opt)
    if opt.tensorboard and opt.is_master_node:
        from torch.utils.tensorboard import SummaryWriter
        if opt.begin_epoch == 1:
            tb_writer = SummaryWriter(log_dir=opt.result_path)
        else:
            tb_writer = SummaryWriter(log_dir=opt.result_path,
                                      purge_step=opt.begin_epoch)
    else:
        tb_writer = None
    prev_val_loss = None
    for i in range(opt.begin_epoch, opt.n_epochs + 1):
        if not opt.no_train:
            if opt.distributed:
                train_sampler.set_epoch(i)
            current_lr = get_lr(optimizer)
            train_epoch(i, train_loader, model, criterion, optimizer,
                        opt.device, current_lr, train_logger,
                        train_batch_logger, tb_writer, opt.distributed)
            if i % opt.checkpoint == 0 and opt.is_master_node:
                save_file_path = opt.result_path / 'save_{}.pth'.format(i)
                save_checkpoint(save_file_path, i, opt.arch, model, optimizer,
                                scheduler)
        if not opt.no_val:
            prev_val_loss = val_epoch(i, val_loader, model, criterion,
                                      opt.device, val_logger, tb_writer,
                                      opt.distributed)
        if not opt.no_train and opt.lr_scheduler == 'multistep':
            scheduler.step()
        elif not opt.no_train and opt.lr_scheduler == 'plateau':
            scheduler.step(prev_val_loss)
A dataloader is created from a dataset, and in k-fold cross-validation that dataset is one fold's share of the k-fold split.
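Your requirement, "I want to use the dataloaders defined outside the fold loop", doesn't make sense. A dataloader is tied to one fixed dataset split, while k-fold needs a different split for every fold. If you want to do k-fold cross-validation, you have to create a new train/validation split, and new dataloaders from it, in each iteration of the fold loop.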
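To make "a different split for every fold" concrete, here is a minimal pure-Python sketch of how k-fold assigns sample indices. `kfold_indices` is a hypothetical helper written only for illustration; `sklearn.model_selection.KFold` does the same job if scikit-learn is available.

```python
import random
from typing import Iterator, List, Tuple


def kfold_indices(n_samples: int, n_folds: int, shuffle: bool = True,
                  seed: int = 0) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train_indices, val_indices) for each of the n_folds folds.

    Every sample is in the validation set of exactly one fold and in the
    training set of the other n_folds - 1 folds.
    """
    if not 2 <= n_folds <= n_samples:
        raise ValueError("n_folds must be between 2 and n_samples")
    indices = list(range(n_samples))
    if shuffle:
        # One fixed shuffle shared by all folds, so the folds never overlap.
        random.Random(seed).shuffle(indices)
    # The first n_samples % n_folds folds get one extra sample
    # (the same fold sizes scikit-learn's KFold uses).
    base, extra = divmod(n_samples, n_folds)
    start = 0
    for k in range(n_folds):
        size = base + (1 if k < extra else 0)
        val_idx = indices[start:start + size]
        train_idx = indices[:start] + indices[start + size:]
        yield train_idx, val_idx
        start += size


if __name__ == "__main__":
    for k, (train_idx, val_idx) in enumerate(kfold_indices(10, 3, shuffle=False)):
        print(k, "train:", train_idx, "val:", val_idx)
```

With `shuffle=False`, `kfold_indices(10, 3)` gives the validation folds `[0, 1, 2, 3]`, `[4, 5, 6]` and `[7, 8, 9]`, so each fold needs its own pair of datasets and therefore its own pair of dataloaders.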
The overall structure, in pseudocode:
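dataset = ...  # the full dataset, built once
for k in range(n_folds):
    # a different split for every fold, so the split has to depend on k
    train_dataset, valid_dataset = split_dataset(dataset, k, n_folds)
    train_dataloader = ...  # create from train_dataset
    valid_dataloader = ...  # create from valid_dataset
    model, optimizer, scheduler = ...  # re-initialize for every fold
    for epoch in range(n_epochs):
        train_epoch(train_dataloader, ...)
        val_epoch(valid_dataloader, ...)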
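If you would rather keep `get_train_utils` / `get_val_utils`, one option (a sketch, not the only way) is to build the underlying dataset once outside the fold loop, since that is usually the expensive part, and rebuild only the cheap `DataLoader` wrappers inside it, using `torch.utils.data.Subset` to select each fold's indices. In your code the equivalent change would be to pass the fold's indices into `get_train_utils` / `get_val_utils` and call them inside the fold loop. `make_fold_loaders` and the toy `TensorDataset` below are assumptions standing in for your real data pipeline.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset


def make_fold_loaders(train_pool, val_pool, train_idx, val_idx, batch_size):
    """Build fresh DataLoaders for one fold (hypothetical helper).

    train_pool and val_pool must hold the same samples in the same order;
    they may differ only in their transforms (e.g. augmentation for training).
    """
    train_loader = DataLoader(Subset(train_pool, train_idx),
                              batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(Subset(val_pool, val_idx),
                            batch_size=batch_size, shuffle=False)
    return train_loader, val_loader


def main(n_samples=10, n_folds=3, seed=0):
    # Toy data standing in for the real dataset (assumption for this sketch).
    x = torch.arange(n_samples, dtype=torch.float32).unsqueeze(1)
    y = torch.arange(n_samples)
    dataset = TensorDataset(x, y)  # built ONCE, outside the fold loop

    # Shuffle once, then cut the permutation into n_folds nearly equal parts.
    g = torch.Generator().manual_seed(seed)
    folds = torch.tensor_split(torch.randperm(n_samples, generator=g), n_folds)

    for k in range(n_folds):
        val_idx = folds[k].tolist()
        train_idx = torch.cat([f for j, f in enumerate(folds) if j != k]).tolist()
        # Rebuilt INSIDE the loop, because every fold has a different split.
        train_loader, val_loader = make_fold_loaders(
            dataset, dataset, train_idx, val_idx, batch_size=4)
        # Re-create model, optimizer and scheduler here, then run the usual
        # epoch loop (train_epoch / val_epoch) on train_loader / val_loader.
        val_labels = sorted(int(t) for _, yb in val_loader for t in yb)
        print(f"fold {k}: {len(train_loader.dataset)} train samples, "
              f"val labels {val_labels}")


if __name__ == "__main__":
    main()
```

Because your training and validation pipelines apply different transforms (augmentation only for training), `train_pool` and `val_pool` would in practice be two dataset objects over the same samples with different transforms; passing the same object twice, as the toy example does, only works when the transforms are identical. If you run with `opt.distributed`, the `DistributedSampler` would also have to be created per fold, on the fold's `Subset`.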