import torch

def test1():
    layer = torch.nn.Linear(100, 10)
    x = 5 - torch.sum(layer(torch.ones(100)))
    x.backward()
    # shrink the weight and its gradient from [10, 100] to [10, 90]
    layer.weight.data = layer.weight.data[:, :90]
    layer.weight.grad.data = layer.weight.grad.data[:, :90]
    x = 5 - torch.sum(layer(torch.ones(90)))
    x.backward()

test1()
This fails with an error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-3-bb36a010bd86> in <cell line: 10>()
8 x = 5 - torch.sum(layer(torch.ones(90)))
9 x.backward()
---> 10 test1()
11 # and this works as well
12
2 frames
/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
249 # some Python versions print out the first line of a multi-line function
250 # calls in the traceback and some print out the last line
--> 251 Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
252 tensors,
253 grad_tensors_,
RuntimeError: Function TBackward0 returned an invalid gradient at index 0 - got [10, 90] but expected shape compatible with [10, 100]
This works:

import torch

def test2():
    layer = torch.nn.Linear(100, 10)
    x = 5 - torch.sum(layer(torch.ones(100)))
    x.backward()
    del x  # main change: drop the last reference to the first graph
    layer.weight.data = layer.weight.data[:, :90]
    layer.weight.grad.data = layer.weight.grad.data[:, :90]
    x = 5 - torch.sum(layer(torch.ones(90)))
    x.backward()

test2()
And this works as well:

import torch

def test3():
    layer = torch.nn.Linear(100, 10)
    x = 5 - torch.sum(layer(torch.ones(100)))
    x.backward()
    layer.weight.data = layer.weight.data[:, :90]
    layer.weight.grad.data = layer.weight.grad.data[:, :90]
    layer.weight = torch.nn.Parameter(layer.weight)  # main change: re-wrap as a fresh Parameter
    x = 5 - torch.sum(layer(torch.ones(90)))
    x.backward()

test3()
I ran into this problem while trying to implement a paper on model pruning. I believe it has something to do with the autograd graph, but I am not sure what exactly is going on. Is there an explanation for why these almost identical snippets fail or work?
As you guessed, the problem is the computation graph that gets built when you run the backward pass: as long as x is alive, it keeps that graph alive, including the cached gradient-accumulator node for layer.weight, which still expects the original [10, 100] shape, so the second backward in test1 rejects the new [10, 90] gradient.
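A minimal sketch of how to see this; the interpretation in the comments is my reading of the error message, not documented behavior:

import torch

layer = torch.nn.Linear(100, 10)
x = 5 - torch.sum(layer(torch.ones(100)))
x.backward()

# backward() frees the saved tensors, but as long as x exists its
# grad_fn still references the graph's node structure, which in turn
# keeps the accumulator node for layer.weight (shape [10, 100]) alive.
print(x.grad_fn)                 # root node of the retained graph
print(x.grad_fn.next_functions)  # edges leading back to AccumulateGrad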
A better way to do pruning can be found in this PyTorch tutorial: https://pytorch.org/tutorials/intermediate/pruning_tutorial.html
You generally don't fiddle with the input dimensions; you just turn off some of the weights by setting them to zero. In the first working case (test2) I think it succeeds because you reset the graph; in the second (test3), because you set the model parameter to a truncated version of the weights wrapped in a fresh Parameter. A sketch of the masking approach follows.
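Here is a minimal sketch of that masking approach using torch.nn.utils.prune; the layer sizes and the amount/norm settings are just illustrative:

import torch
import torch.nn.utils.prune as prune

layer = torch.nn.Linear(100, 10)

# Zero out 10% of the input columns (L1-norm structured pruning along
# dim=1) instead of slicing the tensor: the weight keeps its [10, 100]
# shape, so the autograd metadata never goes stale.
prune.ln_structured(layer, name="weight", amount=0.1, n=1, dim=1)

x = 5 - torch.sum(layer(torch.ones(100)))
x.backward()  # works: shapes are unchanged, only values are masked

Calling prune.remove(layer, "weight") afterwards would make the mask permanent by folding it into the weight tensor.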