我的问题是关于pytorch register_hook
的语法。
x = torch.tensor([1.], requires_grad=True)
y = x**2
z = 2*y
x.register_hook(print)
y.register_hook(print)
z.backward()
输出:
tensor([2.])
tensor([4.])
这个片段分别打印了z
w.r.t x
和y
的渐变。
现在我的(最可能是微不足道的)问题是如何返回中间渐变(而不仅仅是打印)?
更新:
似乎调用retain_grad()
解决了叶节点的问题。恩。 y.retain_grad()
。
但是,retain_grad
似乎没有为非叶节点解决它。有什么建议?
我认为你可以使用这些钩子将渐变存储在一个全局变量中:
grads = []
x = torch.tensor([1.], requires_grad=True)
y = x**2 + 1
z = 2*y
x.register_hook(lambda d:grads.append(d))
y.register_hook(lambda d:grads.append(d))
z.backward()
但是你很可能还需要记住这些梯度的相应张量。在这种情况下,我们使用dict
而不是list
稍微扩展一下:
grads = {}
x = torch.tensor([1.,2.], requires_grad=True)
y = x**2 + 1
z = 2*y
def store(grad,parent):
print(grad,parent)
grads[parent] = grad.clone()
x.register_hook(lambda grad:store(grad,x))
y.register_hook(lambda grad:store(grad,y))
z.sum().backward()
现在,您可以使用y
访问张量grads[y]
的毕业