如果你有多个网络(从
nn.Module
继承的多个对象的意义上),你必须这样做有一个简单的原因:当构造torch.nn.optim.Optimizer
对象时,它将需要优化的参数作为参数。对于你的情况:
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
这也让您可以自由地独立改变参数作为学习率。如果您不需要,您可以创建一个继承自
nn.Module
并包含网络、编码器和解码器的新类,或者创建一组参数以提供给优化器,如此处所解释:
nets = [encoder, decoder]
parameters = set()
for net in nets:
parameters |= set(net.parameters())
其中
|
是此上下文中集合的并集运算符。