如何在pytorch中返回中间渐变(对于非叶节点)?

问题描述 投票:0回答:1

我的问题是关于pytorch register_hook的语法。

x = torch.tensor([1.], requires_grad=True)
y = x**2
z = 2*y

x.register_hook(print)
y.register_hook(print)

z.backward()

输出:

tensor([2.])
tensor([4.])

这个片段分别打印了z w.r.t xy的渐变。

现在我的(最可能是微不足道的)问题是如何返回中间渐变(而不仅仅是打印)?

更新:

似乎调用retain_grad()解决了叶节点的问题。恩。 y.retain_grad()

但是,retain_grad似乎没有为非叶节点解决它。有什么建议?

python gradient pytorch register-hook
1个回答
0
投票

我认为你可以使用这些钩子将渐变存储在一个全局变量中:

grads = []
x = torch.tensor([1.], requires_grad=True)
y = x**2 + 1
z = 2*y

x.register_hook(lambda d:grads.append(d))
y.register_hook(lambda d:grads.append(d))

z.backward()

但是你很可能还需要记住这些梯度的相应张量。在这种情况下,我们使用dict而不是list稍微扩展一下:

grads = {}
x = torch.tensor([1.,2.], requires_grad=True)
y = x**2 + 1
z = 2*y

def store(grad,parent):
    print(grad,parent)
    grads[parent] = grad.clone()

x.register_hook(lambda grad:store(grad,x))
y.register_hook(lambda grad:store(grad,y))

z.sum().backward()

现在,您可以使用y访问张量grads[y]的毕业

© www.soinside.com 2019 - 2024. All rights reserved.