How can we elegantly capture the updates made by optimizer.step()?


I want to implement a way to monitor the update-to-data ratio during PyTorch training in TensorBoard, following the idea mentioned in Karpathy's video. I have come up with a solution, but I am looking for a more elegant and configurable approach.

My current implementation modifies the training loop directly, as follows:

for step, batch in enumerate(data_loader):
    x, y = batch
    optimizer.zero_grad()
    # snapshot the parameters before the update
    for name, param in model.named_parameters():
        if param.requires_grad and "weight" in name:
            param.data_before_step = param.data.clone()
    output = model(x)
    loss = loss_fn(output, y)
    loss.backward()
    optimizer.step()
    lr_scheduler.step()
    # compare the updated parameters with the snapshot
    for name, param in model.named_parameters():
        if hasattr(param, "data_before_step"):
            update = param.data - param.data_before_step
            update_to_data = (update.std() / param.data_before_step.std()).log10().item()
            summary_writer.add_scalar(f"Update:data ratio {name}", update_to_data, epoch * len(data_loader) + step)
            param.data_before_step = param.data.clone()

However, this approach adds code directly to the training loop, which clutters it, and making it configurable would require if-else statements, cluttering it even further.

I have also explored using PyTorch hooks for this. I have successfully implemented a hook that tracks gradients:

class GradToDataRatioHook:
    def __init__(self, name, param, start_step, summary_writer):
        self.name = name
        self.param = param
        self.summary_writer = summary_writer
        self.grads = []
        self.grads_to_data = []
        self.param.update_step = start_step

    # called each time this parameter's gradient is computed
    def __call__(self, grad):
        self.grads.append(grad.std().item())
        self.grads_to_data.append((grad.std() / (self.param.data.std() + 1e-5)).log10().item())
        self.summary_writer.add_scalar(f"Grad {self.name}", self.grads[-1], self.param.update_step)
        self.summary_writer.add_scalar(f"Grad:data ratio {self.name}", self.grads_to_data[-1], self.param.update_step)
        self.param.update_step += 1
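For completeness, registering such a per-parameter hook can be sketched roughly as follows. To keep the snippet self-contained, it uses a closure and a plain dict instead of the class and TensorBoard writer above:

```python
import torch
import torch.nn as nn

# Standalone sketch: per-parameter gradient hooks via closures,
# logging into a plain dict instead of TensorBoard.
net = nn.Linear(4, 2)
grad_log = {}

def make_hook(name, param):
    def hook(grad):
        # grad:data ratio on a log10 scale, as in the hook class above
        ratio = (grad.std() / (param.detach().std() + 1e-5)).log10().item()
        grad_log.setdefault(name, []).append(ratio)
    return hook

for name, param in net.named_parameters():
    if param.requires_grad and "weight" in name:
        param.register_hook(make_hook(name, param))

loss = net(torch.randn(8, 4)).pow(2).mean()
loss.backward()  # the hook fires when the gradient is computed
```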

However, implementing a similar hook to capture the updates seems tricky. As I understand it, param.register_hook(...) registers a hook that is called when the gradient is computed, i.e. before optimizer.step() is called. While the gradient and learning rate directly give the update for vanilla SGD, modern optimizers like Adam make the update more complicated. I am looking for a solution that captures the updates in an optimizer-agnostic way, ideally using PyTorch hooks, but any suggestions or alternative approaches would also be appreciated.
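One direction I have not fully explored is the optimizer-level hook API (register_step_pre_hook / register_step_post_hook, available in recent PyTorch versions, 2.0+ if I am not mistaken; worth verifying). A minimal sketch of how it might look, logging to a plain dict instead of TensorBoard:

```python
import torch
import torch.nn as nn

# Sketch of an optimizer-agnostic update:data monitor built on
# optimizer step hooks (assumed available in this PyTorch version).
# Ratios go into a plain dict here rather than TensorBoard.
class UpdateRatioMonitor:
    def __init__(self, model, optimizer):
        self.tracked = [(n, p) for n, p in model.named_parameters()
                        if p.requires_grad and "weight" in n]
        self.before = {}
        self.ratios = {n: [] for n, _ in self.tracked}
        optimizer.register_step_pre_hook(self._pre)
        optimizer.register_step_post_hook(self._post)

    def _pre(self, optimizer, args, kwargs):
        # snapshot parameters just before the update is applied
        for name, p in self.tracked:
            self.before[name] = p.detach().clone()

    def _post(self, optimizer, args, kwargs):
        # compare post-update parameters with the snapshot
        for name, p in self.tracked:
            update = p.detach() - self.before[name]
            ratio = (update.std() / self.before[name].std()).log10().item()
            self.ratios[name].append(ratio)

net = nn.Linear(4, 2)
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
monitor = UpdateRatioMonitor(net, opt)

for _ in range(3):
    loss = net(torch.randn(8, 4)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()  # both hooks fire automatically around the step
```

This keeps the training loop untouched, but I am not sure whether it is the idiomatic way to do it.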

machine-learning pytorch hook
1 Answer

Probably the simplest approach is to dump the optimizer's param groups and state dict after every backward call. This lets you capture the parameters, gradients, and optimizer state.

The code below shows a simple MLP example, but the log_state function should work for any model, as long as the param groups/state dict are not nested more than one level deep.

import torch
import torch.nn as nn

# example model
class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)
        
    def forward(self, x):
        h = self.fc1(x)
        pred = self.fc2(self.relu(h))
        return pred

def log_state(opt):
    output = {}

    # log state dict
    state_dict = {}
    for key, value in opt.state_dict().items():
        if key == 'state':
            state_dict[key] = {}
            for state_key, state_value in value.items():
                state_dict[key][state_key] = {}
                for k, v in state_value.items():
                    if torch.is_tensor(v):
                        state_dict[key][state_key][k] = v.cpu().clone() # move tensor to cpu
                    else:
                        state_dict[key][state_key][k] = v
        else:
            state_dict[key] = value

    output['state_dict'] = state_dict

    # log param groups
    param_groups = []
    for group in opt.param_groups:
        param_group = {key: value for key, value in group.items() if key != 'params'}
        param_group['params'] = []
        param_group['param_grads'] = []
        for param in group['params']:
            param_group['params'].append(param.data.cpu().clone()) # move tensor to cpu
            # log gradients
            if param.grad is not None:
                param_group['param_grads'].append(param.grad.data.cpu().clone()) # move tensor to cpu
            else:
                param_group['param_grads'].append(None)

        param_groups.append(param_group)

    output['param_groups'] = param_groups
    
    return output



net = MLP(64, 20, 10)

opt = torch.optim.Adam(net.parameters(), lr=1e-3)

state_log = {}
for i in range(5):
    x = torch.randn(8, 64)
    y = torch.randn(8,10)
    p = net(x)
    loss = nn.functional.mse_loss(p, y)
    opt.zero_grad()
    loss.backward()
    state_log[i] = log_state(opt) # log optimizer state every step
    opt.step()

The code logs after backward in order to capture the gradients. You can add extra logging after step, but the values updated by step (parameters, optimizer averages, etc.) will only be captured on the next iteration. You can also use the values from one iteration to recreate the next. For example, with Adam:

step = 2

old_state = state_log[step]
new_state = state_log[step+1]

# adam params
b1 = 0.9
b2 = 0.999
eps = 1e-8
lr = 1e-3

w = old_state['param_groups'][0]['params'][0]
g = old_state['param_groups'][0]['param_grads'][0]

m_old = old_state['state_dict']['state'][0]['exp_avg']
v_old = old_state['state_dict']['state'][0]['exp_avg_sq']

m = b1 * m_old + (1-b1)*g
v = b2 * v_old + (1-b2)*(g.pow(2))

m_hat = m.div(1-b1**(step+1))
v_hat = v.div(1-b2**(step+1))

w_new = w - lr * m_hat / (torch.sqrt(v_hat) + eps)

torch.allclose(w_new, new_state['param_groups'][0]['params'][0])
> True
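Building on the same idea, the update:data ratio the question asks about can be recovered from per-step parameter snapshots. A minimal self-contained sketch (plain lists instead of TensorBoard):

```python
import torch
import torch.nn as nn

# Standalone sketch: recover the update:data ratio from parameter
# snapshots taken each iteration, without touching the optimizer.
net = nn.Linear(4, 2)
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

snapshots = []
for _ in range(3):
    loss = net(torch.randn(8, 4)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    snapshots.append(net.weight.detach().clone())  # value before the step
    opt.step()
snapshots.append(net.weight.detach().clone())      # final value

# consecutive snapshots differ by exactly one optimizer step
ratios = []
for before, after in zip(snapshots[:-1], snapshots[1:]):
    update = after - before
    ratios.append((update.std() / before.std()).log10().item())
```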