我试图了解钩子在
PyTorch
中的具体代码操作方式。我有一个模型,我想在我的代码中设置一个前向和后向挂钩。我想在模型中的特定层之后设置一个钩子,我想最简单的方法是为此特定的module
设置一个钩子。这个介绍性的 video 警告向后模块包含一个错误,但我不确定情况是否仍然如此。
我的代码如下:
def __init__(self, model, attention_layer_name='desired_name_module',discard_ratio=0.9):
self.model = model
self.discard_ratio = discard_ratio
for name, module in self.model.named_modules():
if attention_layer_name in name:
module.register_forward_hook(self.get_attention)
module.register_backward_hook(self.get_attention_gradient)
self.attentions = []
self.attention_gradients = []
def get_attention(self, module, input, output):
self.attentions.append(output.cpu())
def get_attention_gradient(self, module, grad_input, grad_output):
self.attention_gradients.append(grad_input[0].cpu())
def __call__(self, input_tensor, category_index):
self.model.zero_grad()
output = self.model(input_tensor)
loss = ...
loss.backward()
我很困惑地理解以下几行代码如何工作:
module.register_forward_hook(self.get_attention)
module.register_backward_hook(self.get_attention_gradient)
我正在向所需的模块注册一个钩子,但是,然后,我在每种情况下都调用一个函数,而无需任何输入。我的问题是
Python
,这个调用到底是如何工作的?当函数被调用时,register_forward_hook
和register_backward_hook
的参数如何运行?
当执行特定操作时,挂钩允许您执行特定函数(称为“回调”)。在这种情况下,您期望在访问
self.get_attention
的 forward
函数后调用 module
。举一个简单的例子来说明钩子的样子。我定义了一个简单的类,您可以通过register_hook
在其上注册新的回调,然后当调用实例时(通过__call__
),将使用提供的参数调用所有钩子:
class Obj:
def __init__(self):
self.hooks = []
def register_hook(self, hook):
self.hooks.append(hook)
def __call__(self, x, y):
print('instance called')
for hook in self.hooks:
hook(x, y)
首先初始化一个实例并实现两个钩子用于演示目的:
obj = Obj()
def foo(x, y):
print(f'foo called with {x} and {y}')
def bar(x, _):
print(f'bar called with {x}')
您可以注册一个钩子并调用实例:
>>> obj.register_hook(foo)
>>> obj('yes', 'no')
instance called
foo called with yes and no
并且可以在上面添加钩子,再次调用进行比较,这里两个钩子都会被触发:
>>> obj.register_hook(bar)
>>> obj('yes', 'no')
instance called
foo called with yes and no
bar called with yes
nn.Module
的转发过程上附加一个钩子,你应该使用register_forward_hook
,参数是一个回调函数,需要module
、args
和output
。每次转发执行时都会触发此回调。但是,对于向后挂钩,您应该使用 register_full_backward_hook
,注册的挂钩必须需要三个参数:module
、grad_input
和 grad_output
。截至最近的 PyTorch 版本,register_backward_hook
已被弃用,并且 不应使用。
这里的一个副作用是您正在使用
self.get_attention
和 self.get_attention_gradient
注册钩子。当传递给寄存器函数时,它们与类实例解除绑定!换句话说,它们将在没有 self
参数的情况下被调用:
self.get_attention(module, input, output)
self.get_attention_gradient(module, grad_input, grad_output)
所以这不起作用,解决这个问题的一个简单方法是在注册时用 lambda 包装钩子。比如:
module.register_forward_hook(
lambda *args, **kwargs: Routine.get_attention(self, *args, **kwargs))
总而言之,您的课程可能如下所示:
class Routine:
def __init__(self, model, attention_layer_name):
self.model = model
for name, module in self.model.named_modules():
if attention_layer_name in name:
module.register_forward_hook(
lambda *args, **kwargs: Routine.get_attention(self, *args, **kwargs))
module.register_full_backward_hook(
lambda *args, **kwargs: Routine.get_attention_gradient(self, *args, **kwargs))
self.attentions = []
self.attention_gradients = []
def get_attention(self, module, input, output):
self.attentions.append(output.cpu())
def get_attention_gradient(self, module, grad_input, grad_output):
print(grad_input)
self.attention_gradients.append(grad_input[0].cpu())
def __call__(self, input_tensor):
self.model.zero_grad()
output = self.model(input_tensor)
loss = output.mean()
loss.backward()
这是一个具有单个线性层模型的最小示例(假设您将类包装器命名为
Routine
):
routine = Routine(nn.Sequential(nn.Linear(10,10)), attention_layer_name='0')
调用实例时,首先用
self.model(input_tensor)
触发前向钩子,然后用loss.backward()
触发后向钩子。
>>> routine(torch.rand(1,10, requires_grad=True))
在实现之后,您的前向钩子将
"attention_layer_name"
层的输出缓存在 self.attentions
中。
>>> routine.attentions
[tensor([[-0.3137, -0.2265, -0.2197, 0.2211, -0.6700,
-0.5034, -0.1878, -1.1334, 0.2025, 0.8679]], grad_fn=<AddmmBackward0>)]
同样适用于
self.attention_gradients
:
>>> routine.attentions_gradients
[tensor([[-0.0501, 0.0393, 0.0353, -0.0257, 0.0083,
0.0426, -0.0004, -0.0095, -0.0759, -0.0213]])]
需要注意的是,缓存的输出和梯度将保留在
self.attentions
和 self.attentions_gradients
中,并在每次执行 Routine.__call__
时附加。