我尝试使用pytorch进行自动渐变。当我正在测试时,我遇到了错误。我的代码如下:
w11 = torch.rand((100,2), requires_grad=True)
w12 = torch.rand((100,2), requires_grad=True)
w12[:,1] = w12[:,1] + 1
w13 = torch.rand((100,2), requires_grad=True)
w13[:,1] = w13[:,1] + 2
out1=(w11-w12)**2
out2=out1.mean()
out2.backward(retain_graph=True)
当您想用require_grad = True替换张量中的某些内容时,使用with torch.no_grad()
,>
w11 = torch.rand((100,2), requires_grad=True) w12 = torch.rand((100,2), requires_grad=True) w13 = torch.rand((100,2), requires_grad=True) with torch.no_grad(): w12[:,1] = w12[:,1] + 1 w13[:,1] = w13[:,1] + 2 out1=(w11-w12)**2 out2=out1.mean() out2.backward(retain_graph=True)
一切都会顺利。