PyTorch 中的 register_forward_hook 和 register_module_forward_hook 有什么区别?

问题描述 投票:0回答:1

正如标题所示,我想了解这两个函数在 PyTorch 中作为前向钩子的功能如何?我看到 regisfter_module_forward_hook 添加了一个全局状态,我假设这意味着所有前向挂钩都有一个函数。是这样吗,或者它的功能与更常用的 register_forward_hook 有什么不同?

我最终编写的目的是从给定网络的所有层计算相同的统计信息,因此用作钩子的函数在所有层中都是相同的。后者是更好的选择吗?

我还没有尝试使用它们,因为我想找出哪一个更适合我的情况。

deep-learning pytorch neural-network
1个回答
0
投票

我只是想找出同样的问题,并在谷歌搜索时发现了你的问题。

一些挖掘:

  • register_forward_hook
    已添加于 此 PR 7 年
  • register_module_forward_hook
    于 3 年前添加于 此 PR

似乎前者需要在每个模块的基础上设置,而后者是您设置一次为每个模块运行的全局钩子。

 test_fwd = nn.modules.module.register_module_forward_hook(lambda *args: fw_hook(1, *args))

register_module_forward_hook寻找

责备
显示了这个相关问题以及 3 个月前的更多详细信息。

听起来后者是适合您情况的更好选择。特别是,考虑最新提交的评论,因为它使其与上下文管理器兼容。

例如,您可以使用它来通过使用这样的上下文管理器来计算每一层上的每个示例激活范数

@contextmanager
def module_hook(hook: Callable):
    handle = nn.modules.module.register_module_forward_hook(hook, always_call=True)
    yield
    handle.remove()

def compute_norms(layer: nn.Module, inputs: Tuple[torch.Tensor], _output: torch.Tensor):
    A = inputs[0].detach()
    layer.norms2 = (A * A).sum(dim=1)

with module_hook(compute_norms):
    outputs = model(data)

print("layer", "norms squared")
for name, layer in model.named_modules():
    if not name:
        continue
    print(f"{name:20s}: {layer.norms2.cpu().numpy()}")

来自 colab

的完整代码
from contextlib import contextmanager
from typing import Callable, Tuple

import torch
import torch.nn as nn

import numpy as np

device = 'cuda' if torch.cuda.is_available() else 'cpu'
data = torch.tensor([[1., 0.], [1., 1.]]).to(device)
bs = data.shape[0]  # batch size

def simple_model(d, num_layers):
    """Creates simple linear neural network initialized to 2*identity"""
    layers = []
    for i in range(num_layers):
        layer = nn.Linear(d, d, bias=False)
        layer.weight.data.copy_(2 * torch.eye(d))
        layers.append(layer)
    return torch.nn.Sequential(*layers)

norms = [torch.zeros(bs).to(device)]

def compute_norms(layer: nn.Module, inputs: Tuple[torch.Tensor], _output: torch.Tensor):
    assert len(inputs) == 1, "multi-input layer??"
    A = inputs[0].detach()
    layer.norms2 = (A * A).sum(dim=1)

model = simple_model(2, 3).to(device)

@contextmanager
def module_hook(hook: Callable):
    handle = nn.modules.module.register_module_forward_hook(hook, always_call=True)
    yield
    handle.remove()

with module_hook(compute_norms):
    outputs = model(data)

np.testing.assert_allclose(model[0].norms2.cpu(), [1, 2])
np.testing.assert_allclose(model[1].norms2.cpu(), [4, 8])
np.testing.assert_allclose(model[2].norms2.cpu(), [16, 32])

print(f"{'layer':20s}: {'norms squared'}")
for name, layer in model.named_modules():
    if not name:
        continue
    print(f"{name:20s}: {layer.norms2.cpu().numpy()}")
#     print(name, layer.norms2)

assert not torch.nn.modules.module._global_forward_hooks, "Some hooks remain"
© www.soinside.com 2019 - 2024. All rights reserved.