在没有retain_graph=True的情况下反向传播两个具有不同损失的网络?

问题描述 投票:0回答:1

我有两个依次执行昂贵计算的网络。

两者的损失目标是相同的,除了第二个网络的损失我想应用掩模。

如何在不使用retain_graph=True的情况下实现这一点?

# tenc          - network1
# unet          - network2

# the work flow is input->tenc->hidden_state->unet->output


params = []
params.append([{'params': tenc.parameters(), 'weight_decay': 1e-3, 'lr': 1e-07}])
params.append([{'params': unet.parameters(), 'weight_decay': 1e-2, 'lr': 1e-06}])
optimizer = torch.optim.AdamW(itertools.chain(*params), lr=1, betas=(0.9, 0.99), eps=1e-07, fused = True, foreach=False)
scheduler = custom_scheduler(optimizer=optimizer, warmup_steps= 30, exponent= 5, random=False)
scaler = torch.cuda.amp.GradScaler() 


loss = torch.nn.functional.mse_loss(model_pred, target, reduction='none')
loss_tenc = loss.mean()
loss_unet = (loss * mask).mean()

scaler.scale(loss_tenc).backward(retain_graph=True)
scaler.scale(loss_unet).backward()
scaler.unscale_(optimizer)

scaler.step(optimizer)
scaler.update()

scheduler.step()
optimizer.zero_grad(set_to_none=True)

loss_tenc应该只优化tenc参数,loss_unet只能优化unet。如有必要,我可能必须使用两种不同的优化器,但为了简单起见,我在这里将它们分组为一个。

python pytorch autograd
1个回答
0
投票

考虑到两个组件都连接到

model_pred
,您可以通过将两个损失项相加来进行一次反向传播:

loss_tenc = loss.mean()
loss_unet = (loss * mask).mean()

scaler.scale(loss_tenc + loss_unet).backward()
© www.soinside.com 2019 - 2024. All rights reserved.