I want to monitor the update-to-data ratio during PyTorch training in TensorBoard, following the idea from Karpathy's video. I have a working solution, but I am looking for a more elegant and configurable approach.
My current implementation modifies the training loop directly:
for step, batch in enumerate(data_loader):
    x, y = batch
    optimizer.zero_grad()
    for name, param in model.named_parameters():
        if param.requires_grad and "weight" in name:
            param.data_before_step = param.data.clone()
    output = model(x)
    loss = loss_fn(output, y)
    loss.backward()
    optimizer.step()
    lr_scheduler.step()
    for name, param in model.named_parameters():
        if hasattr(param, "data_before_step"):
            update = param.data - param.data_before_step
            update_to_data = (update.std() / param.data_before_step.std()).log10().item()
            summary_writer.add_scalar(f"Update:data ratio {name}", update_to_data, epoch * len(data_loader) + step)
            param.data_before_step = param.data.clone()
However, this approach adds code directly to the training loop, which clutters it, and making it configurable would require if-else branches that clutter it even further.
I have also explored PyTorch hooks for this. I successfully implemented a hook that tracks gradients:
class GradToDataRatioHook:
    def __init__(self, name, param, start_step, summary_writer):
        self.name = name
        self.param = param
        self.summary_writer = summary_writer
        self.grads = []
        self.grads_to_data = []
        self.param.update_step = start_step

    def __call__(self, grad):
        self.grads.append(grad.std().item())
        self.grads_to_data.append((grad.std() / (self.param.data.std() + 1e-5)).log10().item())
        self.summary_writer.add_scalar(f"Grad {self.name}", self.grads[-1], self.param.update_step)
        self.summary_writer.add_scalar(f"Grad:data ratio {self.name}", self.grads_to_data[-1], self.param.update_step)
        self.param.update_step += 1
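For completeness, this is how the hook gets attached to the weight parameters. The `StubWriter` below is only a stand-in for `torch.utils.tensorboard.SummaryWriter` (it records what `add_scalar` would write) so the sketch runs standalone; in real training you would pass a real `SummaryWriter`:

```python
import torch
import torch.nn as nn

class StubWriter:
    """Stand-in for SummaryWriter; records (tag, value, step) tuples."""
    def __init__(self):
        self.scalars = []

    def add_scalar(self, tag, value, step):
        self.scalars.append((tag, value, step))

class GradToDataRatioHook:
    def __init__(self, name, param, start_step, summary_writer):
        self.name = name
        self.param = param
        self.summary_writer = summary_writer
        self.grads = []
        self.grads_to_data = []
        self.param.update_step = start_step

    def __call__(self, grad):
        self.grads.append(grad.std().item())
        self.grads_to_data.append((grad.std() / (self.param.data.std() + 1e-5)).log10().item())
        self.summary_writer.add_scalar(f"Grad {self.name}", self.grads[-1], self.param.update_step)
        self.summary_writer.add_scalar(f"Grad:data ratio {self.name}", self.grads_to_data[-1], self.param.update_step)
        self.param.update_step += 1

writer = StubWriter()
model = nn.Linear(4, 2)
for name, param in model.named_parameters():
    if param.requires_grad and "weight" in name:
        param.register_hook(GradToDataRatioHook(name, param, 0, writer))

# one backward pass fires the hook once per registered parameter
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
```

After `loss.backward()`, `writer.scalars` holds the two entries (gradient std and grad:data ratio) for the single weight parameter.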
However, implementing a similar hook to capture the update seems tricky. As I understand it,
param.register_hook(...)
registers a hook that is called when the gradient is computed, i.e. before optimizer.step() is called. For plain SGD the gradient and learning rate directly give the update, but modern optimizers like Adam make the update more involved. I am looking for a way to capture the update in an optimizer-agnostic fashion, preferably with PyTorch hooks. That said, any suggestions or alternative approaches would also be appreciated.
Probably the simplest approach is to dump the optimizer's param groups and state dict after every
backward
call. That lets you capture the parameters, the gradients, and the optimizer state.
The code below uses a simple MLP as an example, but the
log_state
function should work for any model, as long as the param groups / state dict are not deeply nested.
import torch
import torch.nn as nn

# example model
class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        h = self.fc1(x)
        pred = self.fc2(self.relu(h))
        return pred

def log_state(opt):
    output = {}
    # log state dict
    state_dict = {}
    for key, value in opt.state_dict().items():
        if key == 'state':
            state_dict[key] = {}
            for state_key, state_value in value.items():
                state_dict[key][state_key] = {}
                for k, v in state_value.items():
                    if torch.is_tensor(v):
                        state_dict[key][state_key][k] = v.cpu().clone()  # move tensor to cpu
                    else:
                        state_dict[key][state_key][k] = v
        else:
            state_dict[key] = value
    output['state_dict'] = state_dict
    # log param groups
    param_groups = []
    for group in opt.param_groups:
        param_group = {key: value for key, value in group.items() if key != 'params'}
        param_group['params'] = []
        param_group['param_grads'] = []
        for param in group['params']:
            param_group['params'].append(param.data.cpu().clone())  # move tensor to cpu
            # log gradients
            if param.grad is not None:
                param_group['param_grads'].append(param.grad.data.cpu().clone())  # move tensor to cpu
            else:
                param_group['param_grads'].append(None)
        param_groups.append(param_group)
    output['param_groups'] = param_groups
    return output

net = MLP(64, 20, 10)
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

state_log = {}
for i in range(5):
    x = torch.randn(8, 64)
    y = torch.randn(8, 10)
    p = net(x)
    loss = nn.functional.mse_loss(p, y)
    opt.zero_grad()
    loss.backward()
    state_log[i] = log_state(opt)  # log optimizer state every step
    opt.step()
The code logs after
backward
so that the gradients are captured. You can add extra logging after step
as well, but the values updated by step
(parameters, optimizer running averages, etc.) will be captured on the next iteration anyway. You can also use the values from one iteration to recreate the next one. For example, with Adam:
step = 2
old_state = state_log[step]
new_state = state_log[step+1]
# adam params
b1 = 0.9
b2 = 0.999
eps = 1e-8
lr = 1e-3
w = old_state['param_groups'][0]['params'][0]
g = old_state['param_groups'][0]['param_grads'][0]
m_old = old_state['state_dict']['state'][0]['exp_avg']
v_old = old_state['state_dict']['state'][0]['exp_avg_sq']
m = b1 * m_old + (1-b1)*g
v = b2 * v_old + (1-b2)*(g.pow(2))
m_hat = m.div(1-b1**(step+1))
v_hat = v.div(1-b2**(step+1))
w_new = w - lr * m_hat / (torch.sqrt(v_hat) + eps)
torch.allclose(w_new, new_state['param_groups'][0]['params'][0])
> True