Printing intermediate gradient values during the backward pass in PyTorch using hooks


I am trying to use register_full_backward_hook to print the value of every intermediate gradient during my model's backward pass:

import torch

class func_NN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.rand(1))
        self.b = torch.nn.Parameter(torch.rand(1))

    def forward(self, inp):
        mul_x = torch.cos(self.a.view(-1,1)*inp)
        sum_x = mul_x - self.b
        return sum_x

# hook function
def backward_hook(module, grad_input, grad_output):
    print("module: ", module)
    print("inp: ", grad_input)
    print("out: ", grad_output) 

# Training
# Generate labels
a = torch.Tensor([0.5])
b = torch.Tensor([0.8])
x = torch.linspace(-1, 1, 10)
y = a*x + (0.1**0.5)*torch.randn_like(x)*(0.001) + b
inp = torch.linspace(-1, 1, 10)
foo = func_NN()
handle_ = foo.register_full_backward_hook(backward_hook)
loss = torch.nn.MSELoss()
optim = torch.optim.Adam(foo.parameters(),lr=0.001)

t_l = []
for i in range(2):
    optim.zero_grad()
    l = loss(y, foo.forward(inp=inp))
    t_l.append(l.detach())
    l.backward()
    optim.step()
handle_.remove()

But this does not produce the result I expect.

My goal is to print the gradients of the non-leaf tensors, e.g. sum_x and mul_x. Please help.

python pytorch tensor backpropagation autograd
1 Answer

PyTorch module backward hooks are designed to capture gradients at the module boundary: grad_input and grad_output are the gradients with respect to the module's inputs and outputs. You cannot use them to get the gradients of intermediate tensors created inside forward.
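For reference, here is a minimal sketch (reusing the asker's func_NN; boundary_hook is a hypothetical name) of what a full backward hook actually reports: grad_output[0] is the gradient of the module's return value sum_x, while mul_x, which only exists inside forward, never shows up:

def boundary_hook(module, grad_input, grad_output):
    # grad_output[0] is dL/d(sum_x), because sum_x is what forward returns
    print("grad wrt module output:", grad_output[0])
    # grad_input[0] is dL/d(inp); it may be None if inp does not require grad
    print("grad wrt module input:", grad_input[0])

foo = func_NN()
handle = foo.register_full_backward_hook(boundary_hook)
out = foo(torch.linspace(-1, 1, 10))  # hooks fire only when the module is called, not via .forward()
out.sum().backward()
handle.remove()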

If you want the gradients of intermediate tensors, you need to save them onto the model's state and call retain_grad() on them:

import torch

class func_NN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.rand(1))
        self.b = torch.nn.Parameter(torch.rand(1))

    def forward(self, inp):
        mul_x = torch.cos(self.a.view(-1, 1) * inp)
        sum_x = mul_x - self.b

        # Retain gradients for intermediate variables
        mul_x.retain_grad()
        sum_x.retain_grad()

        # Store references to the intermediate tensors
        self.mul_x = mul_x
        self.sum_x = sum_x

        return sum_x

a = torch.Tensor([0.5])
b = torch.Tensor([0.8])
x = torch.linspace(-1, 1, 10)
y = a*x + (0.1**0.5)*torch.randn_like(x)*(0.001) + b
inp = torch.linspace(-1, 1, 10)
foo = func_NN()
loss = torch.nn.MSELoss()


# Call the module itself rather than .forward() so module hooks would also fire
l = loss(foo(inp), y)
l.backward()

# Gradients of the intermediate (non-leaf) tensors, available after backward()
print(foo.mul_x.grad)
print(foo.sum_x.grad)
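
If the goal is specifically to print the gradients during the backward pass, as in the original question, an alternative is a tensor-level hook via Tensor.register_hook, which calls its function with the gradient as soon as it is computed. A sketch under the same setup (func_NN_hooked is a hypothetical name; loss, inp, and y are as defined above):

class func_NN_hooked(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.rand(1))
        self.b = torch.nn.Parameter(torch.rand(1))

    def forward(self, inp):
        mul_x = torch.cos(self.a.view(-1, 1) * inp)
        sum_x = mul_x - self.b
        # Tensor hooks receive the gradient of the tensor during backward
        mul_x.register_hook(lambda g: print("dL/d(mul_x):", g))
        sum_x.register_hook(lambda g: print("dL/d(sum_x):", g))
        return sum_x

foo = func_NN_hooked()
l = loss(foo(inp), y)
l.backward()  # prints both gradients as backward reaches each tensor

Note that this registers a new hook on every forward call, which is fine for debugging but accumulates hooks in a long training loop.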