为什么使用detach()时无法更新权重?

问题描述 投票:0回答:1

我正在为学校项目编写 VGG 模型,但该模型在训练过程中会出现问题。

如果我对两个张量(scaled_similarity 和 target_tensor)使用 detach(),则模型的权重不会更新。但如果我不对张量使用 detach() ,则会出现一条错误消息。

“运行时错误:梯度计算所需的变量之一已通过就地操作修改:[torch.cuda.FloatTensor [1000, 800]],它是 AsStridedBackward0 的输出 0,版本为 14;预期版本为 13。提示:上面的回溯显示了未能计算其梯度的操作。有问题的变量在此处或稍后在任何地方发生了更改。祝你好运!”

我该如何修改这个模型?

该模型是绘画风格识别AI。 将图像传输到由 VGG 模型修改的张量后。

此代码是在谷歌合作实验室实现的。 这段代码是用 python 编写的。

for i, (_image1, _label1) in enumerate(train_loader):
    optimizer.zero_grad()
    image1 = _image1.to(DEVICE)
    label1 = _label1[0]
    vector1_tensor = model(image1)

    if (i == 0): #Exception Case
      image2 = image1
      label2 = label1
      vector2_tensor = vector1_tensor

    similarity = Similarity(vector1_tensor, vector2_tensor)
    similarity_value = similarity.item()
    similarity_vector = [similarity_value]

    if label1 == label2:
      target_vector = [1]
    else :
      target_vector = [0]
    similarity_tensor = torch.tensor(similarity_vector).float()
    target_tensor = torch.tensor(target_vector).float()
    cost = loss(similarity_tensor, target_tensor)
    cost.requires_grad_(True)
    cost.backward()
    optimizer.step()

    if not i % 40:
      print (f'Epoch: {epoch+20:03d}/{EPOCH:03d} | '
             f'Batch {i:03d}/{len(train_loader):03d} |'
             f' Cost: {cost:.4f}')

    #연산량 감소를 위한 텐서 재활용
    image2 = image1
    label2 = label1
    vector2_tensor = vector1_tensor

  PATH = "model_weights.pth"
  torch.save(model.state_dict(), PATH)

希望权重更新没有任何问题,也希望不会出现错误。

python pytorch conv-neural-network artificial-intelligence vgg-net
1个回答
0
投票

detach()
用于从图中分离张量,然后它与反向传播和权重更新无关。 如果后续计算与损失无关,或者如果计算不会反向传播,则应该使用它,否则您将立即出现错误,或者迟早会出现内存不足错误,因为所有这些步骤不必要地保留在内存中。

我无法用您的代码验证错误,我没有得到它们。 然而,在将旧结果提供给下一次迭代之前,从逻辑上将其分离,否则图形将加倍:

if (i == 0): #Exception Case
      image2 = image1 # inputs have no gradients
      label2 = label1
      vector2_tensor = vector1_tensor.detach() # arguable you could try to leave the detach out here
      # Note: do you want this behavior also in your second epoch?

...
vector2_tensor = vector1_tensor.detach() # we used this for loss calculation and do not want to use it for the next loop.
# <end of loop>
© www.soinside.com 2019 - 2024. All rights reserved.