我对PyTorch中的损失函数感到困惑。有些人将损失函数定义为普通的python函数,而另一些人则通过定义继承nn.Module的类来定义损失函数。所以我想知道在什么情况下我们需要通过继承nn.Module来定义损失函数?非常感谢。
通常,仅当您想在此模块中拥有可训练的变量时,才需要从nn.Module继承,否则可以选择继承。
nn.Module
同样适用于损失函数,如果它不包含此类变量(我认为是主要情况),则不需要继承。