将权重张量的一部分设置为requires_grad = True,并将其余值保留为requires_grad = False

问题描述 投票:0回答:1

我正在做某种迁移学习,我加载一个密集模型,然后扩展权重张量,并在扩展后仅训练新值,并保持旧的训练值冻结。在这种情况下,我需要在同一权重张量内将新权重设置为

requires_grad = True
,将旧权重设置为
requires_grad = False
。我尝试了这个,但它不起作用:

old_values = weight_mat[0, :, :length[0]] 
old_values.requires_grad = False # 1. I tried this and they got optimized
old_values = old_values.unsqueeze(0).detach() # 2. I tried this in addition to 1 and they get optimized
new_values = weight_mat[:, :, length[0]:]
new_values.requires_grad = True
weight_mat = torch.cat((old_values, new_values), dim=-1)

在打印不可训练模型的参数数量后,我得到 0,我还检查了历元内的权重张量值,发现所有值都已更新,而我将

old_values
设置为
False

python deep-learning pytorch
1个回答
0
投票

只是为了扩展我的评论,我认为您可以直接将同一张量的不同部分设置为具有不同的

requires_grad
设置。相反,您可以使用后向挂钩有选择地禁用旧值的渐变。向后钩子是一个 PyTorch 函数,可在向后传递期间计算梯度时促进自定义操作的执行。它可以应用于张量或模块。当修改梯度、实现自定义梯度计算或在计算梯度时检查梯度时,后向钩子是有益的。

PyTorch 支持两种主要类型的后向挂钩:

  • 张量后向钩子在特定张量上指定,并在发生这些张量的后向传递时被调用。
  • 模块的向后挂钩:这些挂钩在
    nn.Linear
    nn.Conv2d
    等模块上注册,当这些模块发生向后传递时将被调用。

在您的情况下,您将使用张量后向钩子,它将使用

weight_mat
直接在张量(
register_hook
张量)上注册。这个想法是使用后向钩子将权重张量的
old_values
部分的梯度归零,从而在训练期间有效地冻结这些值。

在 Python 中,代码可能如下所示:

import torch
import torch.nn as nn
import torch.optim as optim

class MyModel(nn.Module):
    def __init__(self, old_length, new_length):
        super(MyModel, self).__init__()
        self.total_length = old_length + new_length
        self.weight_mat = nn.Parameter(torch.randn(1, 1, self.total_length, requires_grad=True))
        self.old_length = old_length

    def forward(self, x):
        return torch.matmul(x, self.weight_mat)

    def zero_grad_old_values(self, grad):
        grad_clone = grad.clone()
        grad_clone[0, 0, :self.old_length] = 0
        return grad_clone

old_length = 5
new_length = 3

model = MyModel(old_length, new_length)

# Register hook
model.weight_mat.register_hook(model.zero_grad_old_values)

optimizer = optim.SGD(model.parameters(), lr=0.01)

input = torch.randn(1, 1, old_length + new_length)
target = torch.randn(1, 1, 1)


optimizer.zero_grad()
output = model(input)
loss = nn.MSELoss()(output, target)
loss.backward()
optimizer.step()

print("Weight matrix gradients after backward pass:")
print(model.weight_mat.grad)
© www.soinside.com 2019 - 2024. All rights reserved.