How can I modify the Adam optimizer to exclude zeros from the computation?


I found this implementation of Adam in this SO question:

```python
import torch

class ADAMOptimizer(torch.optim.Optimizer):
    """
    Implements the ADAM algorithm.
    """
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(ADAMOptimizer, self).__init__(params, defaults)

    def step(self):
        """
        Perform a single optimization step.
        """
        loss = None
        for group in self.param_groups:

            for p in group['params']:
                grad = p.grad.data
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Momentum (Exponential MA of gradients)
                    state['exp_avg'] = torch.zeros_like(p.data)

                    # RMSProp component (exponential MA of squared gradients). Denominator.
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                b1, b2 = group['betas']
                state['step'] += 1

                # Add weight decay if any
                if group['weight_decay'] != 0:
                    grad = grad.add(p.data, alpha=group['weight_decay'])

                # Momentum
                exp_avg = torch.mul(exp_avg, b1) + (1 - b1)*grad
                
                # RMS
                exp_avg_sq = torch.mul(exp_avg_sq, b2) + (1-b2)*(grad*grad)

                mhat = exp_avg / (1 - b1 ** state['step'])
                vhat = exp_avg_sq / (1 - b2 ** state['step'])
                
                denom = torch.sqrt( vhat + group['eps'] )

                p.data = p.data - group['lr'] * mhat / denom 
                
                # Save state
                state['exp_avg'], state['exp_avg_sq'] = exp_avg, exp_avg_sq 

        return loss
```

My problem is that many of my gradients have a value of 0, and that throws off the momentum and velocity terms. I would like to modify the code so that 0 values are not taken into account when computing the momentum and velocity terms (i.e., the first and second moment estimates).
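To make the issue concrete, here is a toy sketch (not taken from the question's code, and assuming β1 = 0.9) of how a zero gradient still decays the running momentum term:

```python
import torch

# Toy illustration: with beta1 = 0.9, a zero gradient still shrinks
# the first-moment estimate on every step.
b1 = 0.9
exp_avg = torch.tensor([1.0])    # pretend this is the current momentum term
grad = torch.tensor([0.0])       # a zero gradient

exp_avg = exp_avg * b1 + (1 - b1) * grad
print(exp_avg)                   # tensor([0.9000]) -- decayed even though grad was 0
```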

However, I'm not sure how to go about it. If it were a simple network where each gradient were just a scalar, I could check whether `p.grad.data == 0`, but since the gradients here are multi-dimensional tensors, I'm not sure how to leave the zeros out of the computation without messing up everything else (e.g., the rest of the update).
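For reference, boolean masks index tensors element-wise regardless of shape, so an update can be restricted to non-zero entries without flattening anything; a minimal sketch with made-up values:

```python
import torch

# Toy example (not from the question): a boolean mask selects entries
# element-wise, so it works for tensors of any shape.
grad = torch.tensor([[0.0, 2.0],
                     [3.0, 0.0]])        # example multi-dimensional gradient
exp_avg = torch.ones_like(grad)          # example first-moment buffer
b1 = 0.9

mask = grad != 0                         # True wherever the gradient is non-zero
# Update only the masked entries; zero-gradient entries keep their old values.
exp_avg[mask] = exp_avg[mask] * b1 + (1 - b1) * grad[mask]
print(exp_avg)
# tensor([[1.0000, 1.1000],
#         [1.2000, 1.0000]])
```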

machine-learning optimization pytorch neural-network gradient-descent
1 Answer

Here is how you could implement these changes:

```python
import torch

class ADAMOptimizer(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(ADAMOptimizer, self).__init__(params, defaults)

    def step(self):
        loss = None
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p.data)
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                b1, b2 = group['betas']

                state['step'] += 1

                # Add weight decay if any
                if group['weight_decay'] != 0:
                    grad = grad.add(p.data, alpha=group['weight_decay'])

                # Create mask for non-zero gradients
                grad_nonzero = grad != 0

                # Update states only where gradients are non-zero
                exp_avg[grad_nonzero] = exp_avg[grad_nonzero] * b1 + (1 - b1) * grad[grad_nonzero]
                exp_avg_sq[grad_nonzero] = exp_avg_sq[grad_nonzero] * b2 + (1 - b2) * (grad[grad_nonzero] ** 2)

                # Bias-corrected first and second moment estimates
                mhat = exp_avg / (1 - b1 ** state['step'])
                vhat = exp_avg_sq / (1 - b2 ** state['step'])

                # Denominator calculation includes epsilon for numerical stability
                denom = (vhat.sqrt() + group['eps'])

                # Parameter update
                p.data = p.data - group['lr'] * mhat / denom

                # Save updated states
                state['exp_avg'], state['exp_avg_sq'] = exp_avg, exp_avg_sq

        return loss
```
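For completeness, here is a short usage sketch (the model and data are made up purely for illustration):

```python
import torch
import torch.nn as nn

# Hypothetical tiny model and data, just to show how the optimizer is called.
model = nn.Linear(4, 1)
optimizer = ADAMOptimizer(model.parameters(), lr=1e-3)

x = torch.randn(8, 4)
y = torch.randn(8, 1)

loss = nn.functional.mse_loss(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```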