我正在尝试构建一个强化学习模型,其中我的参与者网络有一些经过修剪的连接。 当使用 torchrl 的数据收集器 SyncDataCollector 时,深度复制失败(请参见下面的错误)。
这似乎是由于修剪的连接造成的,它按照这篇文章中的建议使用gradfn(而不是requires_grad = True)设置修剪的层。
这是我想要运行的代码示例,其中 SyncDataCollector 尝试对模型进行深度复制,
device = torch.device("cpu")
model = nn.Sequential(
nn.Linear(1,5),
nn.Linear(5,1)
)
mask = torch.tensor([1,0,0,1,0]).reshape(-1,1)
prune.custom_from_mask(model[0], name='weight', mask=mask)
policy_module = TensorDictModule(
model, in_keys=["in"], out_keys=["out"]
)
env = FlyEnv()
collector = SyncDataCollector(
env,
policy_module,
frames_per_batch=1,
total_frames=2,
split_trajs=False,
device=device,
)
这是一个产生错误的最小示例
import torch
from torch import nn
from copy import deepcopy
import torch.nn.utils.prune as prune
device = torch.device("cpu")
model = nn.Sequential(
nn.Linear(1,5),
nn.Linear(5,1)
)
mask = torch.tensor([1,0,0,1,0]).reshape(-1,1)
prune.custom_from_mask(model[0], name='weight', mask=mask)
new_model = deepcopy(model)
错误在哪里
RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment. If you were attempting to deepcopy a module, this may be because of a torch.nn.utils.weight_norm usage, see https://github.com/pytorch/pytorch/pull/103001
我尝试使用
prune.remove(model[0], 'weight')
删除修剪,然后设置 model[0].requires_grad_()
,这修复了结果,但随后所有权重都被训练了......
我认为通过在每次前向传递之前屏蔽它们来“手动”屏蔽修剪后的权重可能会起作用,但它看起来并不高效(也不优雅)。
导致错误的原因是参数被移动到
<param>_orig
并且屏蔽值存储在它旁边。
当 SyncDataCollector 取出参数和缓冲区并将它们放在“元”设备上以创建无状态策略时,这些附加值将被忽略,因为它们不再是参数(因此不会被调用"to"
捕获)。
您可以拨打电话来解决问题
policy_module.module[0].weight = policy_module.module[0].weight.detach()
在创建收集器之前。 这应该没问题,因为
weight
属性无论如何都会在下一次转发调用期间重新计算。
TorchRL 也许应该更好地处理深度复制,尽管在这种情况下,错误是由张量在不应该的地方需要梯度引起的。 IMO 修剪方法应该在前向调用期间计算
"weight"
(就像它们所做的那样),然后修剪