我正在读下文。它使用EMA衰减变量。 Bidirectional Attention Flow for Machine Comprehension
在训练期间,模型的所有权重的移动平均值保持为0.999的指数衰减率。
他们使用TensorFlow,我找到了相关的EMA代码。 https://github.com/allenai/bi-att-flow/blob/master/basic/model.py#L229
在PyTorch中,如何将EMA应用于变量?
移动平均线是梯度下降动量的关键概念。
在PyTorch document你可以找到:
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
将参数momentum
更改为您想要的值。