将简单函数(例如 torch.cat())(或层(例如 MaxPool2d))包装在这样的类中的原因是什么:
class Concat(nn.Module):
def __init__(self, dimension=1):
super(Concat, self).__init__()
self.d = dimension
def forward(self, x):
return torch.cat(x, self.d)
class MP(nn.Module):
def __init__(self, k=2):
super(MP, self).__init__()
self.m = nn.MaxPool2d(kernel_size=k, stride=k)
def forward(self, x):
return self.m(x)
最大的原因应该是它们会被注册到模型中(模型可以引用它们),再加上 pytorch 用户(像我:))大量使用 pytorch hooks 来干扰模型,因此最好能够如果需要附加一些钩子(用于调试、通过更改源代码修改模型行为等)