I found this implementation of Adam in this SO question:
```python
import torch

class ADAMOptimizer(torch.optim.Optimizer):
    """
    Implements the ADAM algorithm, as a starting point.
    """
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(ADAMOptimizer, self).__init__(params, defaults)

    def step(self):
        """
        Performs a single optimization step.
        """
        loss = None
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Momentum (exponential moving average of gradients)
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # RMSProp component (exponential moving average of squared gradients); the denominator
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                b1, b2 = group['betas']
                state['step'] += 1

                # Add weight decay, if any
                if group['weight_decay'] != 0:
                    grad = grad.add(p.data, alpha=group['weight_decay'])

                # Momentum
                exp_avg = torch.mul(exp_avg, b1) + (1 - b1) * grad
                # RMS
                exp_avg_sq = torch.mul(exp_avg_sq, b2) + (1 - b2) * (grad * grad)

                # Bias-corrected first and second moment estimates
                mhat = exp_avg / (1 - b1 ** state['step'])
                vhat = exp_avg_sq / (1 - b2 ** state['step'])

                # Note: eps sits inside the sqrt here; standard Adam adds it outside
                denom = torch.sqrt(vhat + group['eps'])
                p.data = p.data - group['lr'] * mhat / denom

                # Save state
                state['exp_avg'], state['exp_avg_sq'] = exp_avg, exp_avg_sq

        return loss
```
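For context, this class drops in wherever `torch.optim.Adam` would be used. A minimal usage sketch, with a hypothetical toy model and data (all names here are illustrative):

```python
import torch

# Hypothetical toy setup to exercise the optimizer above.
model = torch.nn.Linear(4, 1)
optimizer = ADAMOptimizer(model.parameters(), lr=1e-3)

x, y = torch.randn(8, 4), torch.randn(8, 1)
for _ in range(10):
    optimizer.zero_grad()                  # inherited from torch.optim.Optimizer
    loss = ((model(x) - y) ** 2).mean()
    loss.backward()
    optimizer.step()
```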
My problem is that many of my gradients have a value of 0, and this corrupts the momentum and velocity terms. I would like to modify the code so that 0 values are ignored when computing the momentum and velocity terms (i.e., the first and second moment estimates).
However, I am not sure how to do this. If it were a simple network where each gradient were a single scalar, I could simply check whether `p.grad.data == 0`, but since the gradients here are multi-dimensional tensors, I am not sure how to exclude the zeros from the computation without breaking everything else (e.g., the updates for the remaining entries).
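For reference, boolean masks index tensors of any shape elementwise, so the number of dimensions is not itself the obstacle. A small standalone sketch (the values are illustrative):

```python
import torch

grad = torch.tensor([[0.0, 1.5], [-2.0, 0.0]])   # multi-dimensional gradient with zeros
exp_avg = torch.ones_like(grad)

mask = grad != 0   # boolean tensor, same shape as grad
# Only the masked entries are read and written; the rest of exp_avg is untouched.
exp_avg[mask] = 0.9 * exp_avg[mask] + 0.1 * grad[mask]
print(exp_avg)     # positions where grad == 0 still hold their old value 1.0
```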
Here is how you could implement those changes:
```python
import torch

class ADAMOptimizer(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(ADAMOptimizer, self).__init__(params, defaults)

    def step(self):
        loss = None
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p.data)
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                b1, b2 = group['betas']
                state['step'] += 1

                # Add weight decay, if any
                if group['weight_decay'] != 0:
                    grad = grad.add(p.data, alpha=group['weight_decay'])

                # Create a mask for non-zero gradients
                grad_nonzero = grad != 0

                # Update the moment estimates only where the gradients are non-zero;
                # the masked assignment modifies the state tensors in place.
                exp_avg[grad_nonzero] = exp_avg[grad_nonzero] * b1 + (1 - b1) * grad[grad_nonzero]
                exp_avg_sq[grad_nonzero] = exp_avg_sq[grad_nonzero] * b2 + (1 - b2) * (grad[grad_nonzero] ** 2)

                # Bias-corrected first and second moment estimates
                mhat = exp_avg / (1 - b1 ** state['step'])
                vhat = exp_avg_sq / (1 - b2 ** state['step'])

                # Denominator calculation includes epsilon for numerical stability
                denom = vhat.sqrt() + group['eps']

                # Parameter update
                p.data = p.data - group['lr'] * mhat / denom

                # Save updated states (already updated in place; kept for clarity)
                state['exp_avg'], state['exp_avg_sq'] = exp_avg, exp_avg_sq

        return loss
```
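A quick way to sanity-check the masking, using a hypothetical parameter whose gradient is partly zero:

```python
import torch

p = torch.nn.Parameter(torch.ones(3))
opt = ADAMOptimizer([p], lr=0.1)

p.grad = torch.tensor([0.0, 1.0, -1.0])   # the first entry has a zero gradient
opt.step()

state = opt.state[p]
print(state['exp_avg'])   # first entry stays 0.0: the zero gradient was skipped
print(p.data)             # first parameter entry is unchanged (mhat is 0 there)
```

One caveat: `state['step']` still advances on every call, even for entries whose gradient was zero, so the bias correction for long-masked entries is stronger than in vanilla Adam; whether that matters in practice depends on how sparse your gradients are.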