如何解决pytorch中剪枝模型的深度复制错误

问题描述 投票:0回答:1

我正在尝试构建一个强化学习模型,其中我的参与者网络有一些经过修剪的连接。 当使用 torchrl 的数据收集器 SyncDataCollector 时,深度复制失败(请参见下面的错误)。

这似乎是由于修剪的连接造成的,它按照这篇文章中的建议使用gradfn(而不是requires_grad = True)设置修剪的层。

这是我想要运行的代码示例,其中 SyncDataCollector 尝试对模型进行深度复制,

device = torch.device("cpu")

model = nn.Sequential(
    nn.Linear(1,5),
    nn.Linear(5,1)
)
mask = torch.tensor([1,0,0,1,0]).reshape(-1,1)
prune.custom_from_mask(model[0], name='weight', mask=mask)


policy_module = TensorDictModule(
    model, in_keys=["in"], out_keys=["out"]
)

env = FlyEnv()

collector = SyncDataCollector(
    env,
    policy_module,
    frames_per_batch=1,
    total_frames=2,
    split_trajs=False,
    device=device,
)

这是一个产生错误的最小示例

import torch
from torch import nn
from copy import deepcopy

import torch.nn.utils.prune as prune

device = torch.device("cpu")

model = nn.Sequential(
    nn.Linear(1,5),
    nn.Linear(5,1)
)
mask = torch.tensor([1,0,0,1,0]).reshape(-1,1)
prune.custom_from_mask(model[0], name='weight', mask=mask)

new_model = deepcopy(model)

错误在哪里

RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment.  If you were attempting to deepcopy a module, this may be because of a torch.nn.utils.weight_norm usage, see https://github.com/pytorch/pytorch/pull/103001

我尝试使用

prune.remove(model[0], 'weight')
删除修剪,然后设置
model[0].requires_grad_()
,这修复了结果,但随后所有权重都被训练了......

我认为通过在每次前向传递之前屏蔽它们来“手动”屏蔽修剪后的权重可能会起作用,但它看起来并不高效(也不优雅)。

python pytorch reinforcement-learning deep-copy pruning
1个回答
0
投票

导致错误的原因是参数被移动到

<param>_orig
并且屏蔽值存储在它旁边。 当 SyncDataCollector 取出参数和缓冲区并将它们放在“元”设备上以创建无状态策略时,这些附加值将被忽略,因为它们不再是参数(因此不会被调用
"to"
捕获)。

您可以拨打电话来解决问题

policy_module.module[0].weight = policy_module.module[0].weight.detach()

在创建收集器之前。 这应该没问题,因为

weight
属性无论如何都会在下一次转发调用期间重新计算。

TorchRL 也许应该更好地处理深度复制,尽管在这种情况下,错误是由张量在不应该的地方需要梯度引起的。 IMO 修剪方法应该在前向调用期间计算

"weight"
(就像它们所做的那样),然后修剪

© www.soinside.com 2019 - 2024. All rights reserved.