Pruning nn.Linear weights causes an unexpected error that needs a slightly odd workaround. Explanation wanted

Question · 0 votes · 1 answer

This fails:

import torch
import torch.nn as nn

def test1():
  layer = nn.Linear(100, 10)
  x = 5 - torch.sum(layer(torch.ones(100)))
  x.backward()
  # shrink the weight and its gradient in place to 90 input features
  layer.weight.data = layer.weight.data[:, :90]
  layer.weight.grad.data = layer.weight.grad.data[:, :90]
  x = 5 - torch.sum(layer(torch.ones(90)))
  x.backward()  # RuntimeError on this second backward
test1()

with the error

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-3-bb36a010bd86> in <cell line: 10>()
      8     x = 5 - torch.sum(layer(torch.ones(90)))
      9     x.backward()
---> 10 test1()
     11 # and this works as well
     12 

2 frames
/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    249     # some Python versions print out the first line of a multi-line function
    250     # calls in the traceback and some print out the last line
--> 251     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    252         tensors,
    253         grad_tensors_,

RuntimeError: Function TBackward0 returned an invalid gradient at index 0 - got [10, 90] but expected shape compatible with [10, 100]

This works:

import torch

def test2():
  layer = torch.nn.Linear(100, 10)
  x = 5 - torch.sum(layer(torch.ones(100)))
  x.backward()
  del x  # main change: release the old autograd graph
  layer.weight.data = layer.weight.data[:, :90]
  layer.weight.grad.data = layer.weight.grad.data[:, :90]
  x = 5 - torch.sum(layer(torch.ones(90)))
  x.backward()
test2()

This also works:

import torch

def test3():
  layer = torch.nn.Linear(100, 10)
  x = 5 - torch.sum(layer(torch.ones(100)))
  x.backward()
  layer.weight.data = layer.weight.data[:, :90]
  layer.weight.grad.data = layer.weight.grad.data[:, :90]
  layer.weight = torch.nn.Parameter(layer.weight)  # main change: re-wrap as a fresh Parameter
  x = 5 - torch.sum(layer(torch.ones(90)))
  x.backward()
test3()

I ran into this while trying to implement a paper on model pruning. I believe it has something to do with the autograd graph, but I'm not sure what exactly is going on. Can anyone explain why these nearly identical snippets work or fail?

python machine-learning pytorch artificial-intelligence autograd
1 Answer
0 votes

As you guessed, the problem is the computation graph that gets created when you run the backward pass.
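You can reproduce the same failure without nn.Linear at all, which suggests it is autograd's per-leaf bookkeeping rather than anything Linear-specific. A minimal sketch (the exact node named in the error message may differ from TBackward0):

import torch

w = torch.ones(10, 100, requires_grad=True)
x = (w @ torch.ones(100)).sum()
x.backward()

# Resize the leaf tensor behind autograd's back, as in test1.
w.data = w.data[:, :90]
w.grad.data = w.grad.data[:, :90]

# While x is still alive, autograd keeps the gradient-accumulator node
# for w that recorded the original [10, 100] shape, so this backward
# fails with "got [10, 90] but expected shape compatible with [10, 100]".
x2 = (w @ torch.ones(90)).sum()
x2.backward()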

A better way to prune can be found in this PyTorch tutorial: https://pytorch.org/tutorials/intermediate/pruning_tutorial.html
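For reference, here is a minimal sketch of the tutorial's approach using torch.nn.utils.prune, which masks weights instead of resizing tensors (the 30% amount is just an example value):

import torch
import torch.nn.utils.prune as prune

layer = torch.nn.Linear(100, 10)

# Zero out the 30% of weight entries with the smallest L1 magnitude.
# The layer keeps its [10, 100] shape; prune adds a weight_orig
# parameter and a weight_mask buffer, and layer.weight becomes their
# masked product.
prune.l1_unstructured(layer, name="weight", amount=0.3)

x = torch.sum(layer(torch.ones(100)))
x.backward()  # gradients accumulate into layer.weight_orig; shapes never change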

You normally don't fiddle with the tensor dimensions at all; you just switch some weights off by setting them to zero. As for your snippets: in the first workaround I think del x works because it releases the old graph, and with it autograd's cached bookkeeping for layer.weight that still remembered the [10, 100] shape; in the second it works because wrapping the weight in a new nn.Parameter gives it a fresh autograd identity, so the stale shape information is gone.
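If you really do need to shrink the layer structurally, a safer pattern is to build a new layer and copy the kept weights over, so no tensor's .data is mutated behind autograd's back (shrink_linear below is just an illustrative helper, not a library function):

import torch

def shrink_linear(layer: torch.nn.Linear, new_in: int) -> torch.nn.Linear:
  # A fresh layer means fresh leaf tensors with clean autograd metadata.
  new_layer = torch.nn.Linear(new_in, layer.out_features)
  with torch.no_grad():
    new_layer.weight.copy_(layer.weight[:, :new_in])
    new_layer.bias.copy_(layer.bias)
  return new_layer

layer = torch.nn.Linear(100, 10)
x = 5 - torch.sum(layer(torch.ones(100)))
x.backward()

layer = shrink_linear(layer, 90)
x = 5 - torch.sum(layer(torch.ones(90)))
x.backward()  # no error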
