PyTorch 中的前向和后向钩子到底是如何工作的

问题描述 投票:0回答:1

我试图了解钩子在

PyTorch
中的具体代码操作方式。我有一个模型,我想在我的代码中设置一个前向和后向挂钩。我想在模型中的特定层之后设置一个钩子,我想最简单的方法是为此特定的
module
设置一个钩子。这个介绍性的 video 警告向后模块包含一个错误,但我不确定情况是否仍然如此。

我的代码如下:

def __init__(self, model, attention_layer_name='desired_name_module',discard_ratio=0.9):
  self.model = model
  self.discard_ratio = discard_ratio
  for name, module in self.model.named_modules():
    if attention_layer_name in name:
        module.register_forward_hook(self.get_attention)
        module.register_backward_hook(self.get_attention_gradient)

  self.attentions = []
  self.attention_gradients = []

def get_attention(self, module, input, output):
  self.attentions.append(output.cpu())

def get_attention_gradient(self, module, grad_input, grad_output):
  self.attention_gradients.append(grad_input[0].cpu())

def __call__(self, input_tensor, category_index):
  self.model.zero_grad()
  output = self.model(input_tensor)
  loss = ...
  loss.backward()

我很困惑地理解以下几行代码如何工作:

module.register_forward_hook(self.get_attention)
module.register_backward_hook(self.get_attention_gradient)

我正在向所需的模块注册一个钩子,但是,然后,我在每种情况下都调用一个函数,而无需任何输入。我的问题是

Python
,这个调用到底是如何工作的?当函数被调用时,
register_forward_hook
register_backward_hook
的参数如何运行?

python pytorch hook
1个回答
0
投票

钩子如何工作?

当执行特定操作时,挂钩允许您执行特定函数(称为“回调”)。在这种情况下,您期望在访问

self.get_attention
forward
函数后调用
module
。举一个简单的例子来说明钩子的样子。我定义了一个简单的类,您可以通过
register_hook
在其上注册新的回调,然后当调用实例时(通过
__call__
),将使用提供的参数调用所有钩子:

class Obj:
    def __init__(self):
        self.hooks = []
    
    def register_hook(self, hook):
        self.hooks.append(hook)

    def __call__(self, x, y):
        print('instance called')
        for hook in self.hooks:
            hook(x, y)

首先初始化一个实例并实现两个钩子用于演示目的:

obj = Obj()

def foo(x, y):
    print(f'foo called with {x} and {y}')
def bar(x, _):
    print(f'bar called with {x}')

您可以注册一个钩子并调用实例:

>>> obj.register_hook(foo)
>>> obj('yes', 'no')
instance called
foo called with yes and no

并且可以在上面添加钩子,再次调用进行比较,这里两个钩子都会被触发:

>>> obj.register_hook(bar)
>>> obj('yes', 'no')
instance called
foo called with yes and no
bar called with yes

在 PyTorch 中使用钩子

要在

nn.Module
的转发过程上附加一个钩子,你应该使用
register_forward_hook
,参数是一个回调函数,需要
module
args
output
。每次转发执行时都会触发此回调。但是,对于向后挂钩,您应该使用
register_full_backward_hook
,注册的挂钩必须需要三个参数:
module
grad_input
grad_output
。截至最近的 PyTorch 版本,
register_backward_hook
已被弃用,并且 不应使用

这里的一个副作用是您正在使用

self.get_attention
self.get_attention_gradient
注册钩子。当传递给寄存器函数时,它们与类实例解除绑定!换句话说,它们将在没有
self
参数的情况下被调用:

self.get_attention(module, input, output)
self.get_attention_gradient(module, grad_input, grad_output)

所以这不起作用,解决这个问题的一个简单方法是在注册时用 lambda 包装钩子。比如:

module.register_forward_hook(
    lambda *args, **kwargs: Routine.get_attention(self, *args, **kwargs))

总而言之,您的课程可能如下所示:

class Routine:
    def __init__(self, model, attention_layer_name):
        self.model = model

        for name, module in self.model.named_modules():
            if attention_layer_name in name:
                module.register_forward_hook(
                    lambda *args, **kwargs: Routine.get_attention(self, *args, **kwargs))
                module.register_full_backward_hook(
                    lambda *args, **kwargs: Routine.get_attention_gradient(self, *args, **kwargs))

        self.attentions = []
        self.attention_gradients = []

    def get_attention(self, module, input, output):
        self.attentions.append(output.cpu())

    def get_attention_gradient(self, module, grad_input, grad_output):
        print(grad_input)
        self.attention_gradients.append(grad_input[0].cpu())

    def __call__(self, input_tensor):
        self.model.zero_grad()
        output = self.model(input_tensor)
        loss = output.mean()
        loss.backward()

这是一个具有单个线性层模型的最小示例(假设您将类包装器命名为

Routine
):

routine = Routine(nn.Sequential(nn.Linear(10,10)), attention_layer_name='0')

调用实例时,首先用

self.model(input_tensor)
触发前向钩子,然后用
loss.backward()
触发后向钩子。

>>> routine(torch.rand(1,10, requires_grad=True))

在实现之后,您的前向钩子将

"attention_layer_name"
层的输出缓存在
self.attentions
中。

>>> routine.attentions
[tensor([[-0.3137, -0.2265, -0.2197,  0.2211, -0.6700, 
          -0.5034, -0.1878, -1.1334,  0.2025,  0.8679]], grad_fn=<AddmmBackward0>)]

同样适用于

self.attention_gradients

>>> routine.attentions_gradients
[tensor([[-0.0501,  0.0393,  0.0353, -0.0257,  0.0083,  
           0.0426, -0.0004, -0.0095, -0.0759, -0.0213]])] 

需要注意的是,缓存的输出和梯度将保留在

self.attentions
self.attentions_gradients
中,并在每次执行
Routine.__call__
时附加。

最新问题
© www.soinside.com 2019 - 2024. All rights reserved.