将简单函数包装在类中的原因(PyTorch)

问题描述 投票:0回答:1

将简单函数(例如 torch.cat())(或层(例如 MaxPool2d))包装在这样的类中的原因是什么:

class Concat(nn.Module):
    def __init__(self, dimension=1):
        super(Concat, self).__init__()
        self.d = dimension

    def forward(self, x):
        return torch.cat(x, self.d)

class MP(nn.Module):
    def __init__(self, k=2):
        super(MP, self).__init__()
        self.m = nn.MaxPool2d(kernel_size=k, stride=k)

    def forward(self, x):
        return self.m(x)
python oop deep-learning pytorch computer-vision
1个回答
0
投票

最大的原因应该是它们会被注册到模型中(模型可以引用它们),再加上 pytorch 用户(像我:))大量使用 pytorch hooks 来干扰模型,因此最好能够如果需要附加一些钩子(用于调试、通过更改源代码修改模型行为等)

© www.soinside.com 2019 - 2024. All rights reserved.