我正在尝试使用寄存器向后钩子在模型向后传递期间打印每个中间梯度的值:
class func_NN(torch.nn.Module):
    """Toy model computing cos(a * inp) - b with two learnable scalars."""

    def __init__(self):
        super().__init__()
        # Scalar parameters, randomly initialised in [0, 1).
        self.a = torch.nn.Parameter(torch.rand(1))
        self.b = torch.nn.Parameter(torch.rand(1))

    def forward(self, inp):
        """Return cos(a * inp) - b; the intermediates are non-leaf tensors."""
        scaled = self.a.view(-1, 1) * inp
        return torch.cos(scaled) - self.b
# Hook function: dumps the module and the gradients flowing through it.
def backward_hook(module, grad_input, grad_output):
    """Print the module together with its input and output gradients."""
    for label, value in (("module: ", module),
                         ("inp: ", grad_input),
                         ("out: ", grad_output)):
        print(label, value)
# Training
# Generate noisy linear labels y = a*x + b (+ tiny Gaussian noise).
a = torch.Tensor([0.5])
b = torch.Tensor([0.8])
x = torch.linspace(-1, 1, 10)
y = a * x + (0.1 ** 0.5) * torch.randn_like(x) * (0.001) + b
inp = torch.linspace(-1, 1, 10)

foo = func_NN()
# Register the module-level backward hook; it fires during l.backward().
handle_ = foo.register_full_backward_hook(backward_hook)
loss = torch.nn.MSELoss()
optim = torch.optim.Adam(foo.parameters(), lr=0.001)

t_l = []  # detached loss history
for _ in range(2):
    optim.zero_grad()
    l = loss(y, foo.forward(inp=inp))
    t_l.append(l.detach())
    l.backward()
    optim.step()
# Detach the hook once training is done.
handle_.remove()
但这并没有提供预期的结果。
我的目标是打印非叶节点的梯度,例如
sum_x
和 mul_x
。
请帮忙。
PyTorch 的模块反向钩子（如 register_full_backward_hook）只能抓取模块输入和输出处的梯度。您不能使用它们来获取 forward 内部任意中间张量的梯度。
如果你想获得中间张量的梯度,你需要将它们保存到模型的状态中并对其应用
retain_grad
。
class func_NN(torch.nn.Module):
    """Toy model computing cos(a * inp) - b that keeps intermediate grads.

    After ``backward()`` the gradients of the intermediates are readable as
    ``self.mul_x.grad`` and ``self.sum_x.grad``.
    """

    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.rand(1))
        self.b = torch.nn.Parameter(torch.rand(1))

    def forward(self, inp):
        mul_x = torch.cos(self.a.view(-1, 1) * inp)
        sum_x = mul_x - self.b
        # Non-leaf tensors drop .grad by default; retain_grad() keeps it,
        # and storing each tensor on the module lets callers inspect it
        # after backward().
        for name, tensor in (("mul_x", mul_x), ("sum_x", sum_x)):
            tensor.retain_grad()
            setattr(self, name, tensor)
        return sum_x
# Build the same noisy linear targets as before.
a = torch.Tensor([0.5])
b = torch.Tensor([0.8])
x = torch.linspace(-1, 1, 10)
y = a * x + (0.1 ** 0.5) * torch.randn_like(x) * (0.001) + b
inp = torch.linspace(-1, 1, 10)

foo = func_NN()
loss = torch.nn.MSELoss()

# One forward/backward pass; retain_grad() inside forward() makes the
# intermediate gradients available afterwards.
l = loss(y, foo.forward(inp=inp))
l.backward()
print(foo.mul_x.grad)
print(foo.sum_x.grad)