嵌套神经网络的局部阻塞梯度更新

问题描述 投票:0回答:1

我在 torch 中有两个嵌套的神经网络,我正在计算输出中相对于不同参数的多个损失。下面是一个简单的案例

# two neural networks
>>> A = nn.Linear(10,10)
>>> B = nn.Linear(10,1)

# dummy input
>>> x = torch.rand(1,10, requires_grad=True)

# nested computation
>>> y = B(A(x))

# evaluate two separate Loss functions on the output
>>> Loss1 = f(y)
>>> Loss2 = g(y)

# evaluate backprop through both losses
>>> (Loss1+Loss2).backward()

我希望 Loss1 一起跟踪网络 A 和 B 的梯度变化,但希望 Loss2 只跟踪网络 A 的变化。我知道我可以通过将计算分解为两个反向传播步骤来计算,例如

# two neural networks
>>> A = nn.Linear(10,10)
>>> B = nn.Linear(10,1)

# dummy input
>>> x = torch.rand(1,10, requires_grad=True)

# nested computation
>>> y = B(A(x))

# evaluate first loss function
>>> Loss1 = f(y)

# evaluate backprop through first loss
>>> Loss1.backward()

# disable gradient computation on B
>>> B.requires_grad_(False)

# nested computation
>>> y = B(A(x))

# evaluate second loss function
>>> Loss2 = g(y)

# evaluate backprop through second loss
>>> Loss2.backward()

我不喜欢这种方法,因为它需要通过嵌套神经网络进行多次反向传播计算。有没有办法将第二次丢失标记为不更新网络 B?我在想类似于

g(y).detach()
的东西,但这也消除了网络 A 的梯度。

pytorch loss-function backpropagation
1个回答
0
投票

您正在描述类似于 GAN 优化方法的东西,其中

A
是生成器,
B
是鉴别器。因此,最好将它与 PyTorch 这样的框架中的 GAN 的实现方式进行比较。您无法通过一次向后传递来分离两个梯度信号。你必须有两次向后传球。

|<------------------- L1 
|<---------•••••••••• L2
x ---> A ---> B ---> y 
© www.soinside.com 2019 - 2024. All rights reserved.