如何在不丢失梯度的情况下屏蔽张量?

问题描述 投票:0回答:1

我有一个张量

import torch
a = torch.randn(1, 3, requires_grad=True)
print('a: ', a)
>>> a:  tensor([[0.0200, 1.00200, -4.2000]], requires_grad=True)

还有口罩

mask = torch.zeros_like(a)
mask[0][0] = 1

我想屏蔽我的张量

a
而不将梯度传播到我的掩模张量(在我的实际情况中它有一个梯度)。我尝试了以下

with torch.no_grad():
    b = a * mask
    print('b: ', b)
    >>> b:  tensor([[0.0200, 0.0000, -0.0000]])

但它完全从我的张量中删除了梯度。正确的做法是什么?

pytorch tensor
1个回答
0
投票

您正在丢失渐变,因为您正在使用

with torch.no_grad()
上下文。

with torch.no_grad()
内的所有代码都在没有梯度跟踪的情况下运行。如果您希望传播梯度,请不要使用它。

a = torch.randn(1, 3, requires_grad=True)

mask = torch.zeros_like(a)
mask[0][0] = 1

b = a * mask
print(b)
> tensor([[0.3871, 0.0000, -0.0000]], grad_fn=<MulBackward0>)
© www.soinside.com 2019 - 2024. All rights reserved.