通常,在构造函数中,我们声明要使用的所有层。在forward函数中,我们定义了模型如何从输入到输出运行。
我的问题是,如果直接在forward()
函数中调用那些预定义/内置的[[nn.Modules,该怎么办?这种Keras function API样式对Pytorch是否合法?如果没有,为什么?
TestModel
确实运行成功,没有发出警报。但是与传统方法相比,训练损失将缓慢下降。import torch.nn as nn
from cnn import CNN
class TestModel(nn.Module):
def __init__(self):
super().__init__()
self.num_embeddings = 2020
self.embedding_dim = 51
def forward(self, input):
x = nn.Embedding(self.num_embeddings, self.embedding_dim)(input)
# CNN is a customized class and nn.Module subclassed
# we will ignore the arguments for its instantiation
x = CNN(...)(x)
x = nn.ReLu()(x)
x = nn.Dropout(p=0.2)(x)
return output = x
scope。
例如,如果在模型的forward
函数中定义了一个conv层,则该“层”的范围及其可训练的参数对于该函数而言是[[local,并且在每次调用该函数后将被丢弃forward
方法。您无法更新和训练每次forward
通过后不断丢弃的砝码。但是,当conv层是model
的成员时,其范围将扩展到forward
方法之外,并且只要model
对象存在,可训练参数就会保留。这样,您可以更新和训练模型及其权重。