坦白说,这与 Numpy 数组在屏蔽时维度损失基本上是同一个问题,但针对的是 PyTorch 张量而不是 NumPy 数组。使用等效的 PyTorch 函数
torch.where
或masked 张量 可以解决这些问题,但我发现 Google 搜索有关 PyTorch 张量的内容很快就没有找到答案。所以,我认为等效的 StackOverflow pytorch 标记问题可能对其他人有用!) 我有一个 2D PyTorch 张量(尽管它可能有更多维度),我想对其应用等效形状的二进制掩码。然而,当我应用掩模时,输出只是一维的。应用掩模后如何保持与原始张量相同的尺寸?
例如,对于
import torch
x = torch.tensor([[1.0, 2.0, 8.0], [-4.0, 0.0, 3.0]])
mask = x >=2.0
print(x[mask])
tensor([2., 8., 3.])
输出现在是 1D 而不是 2D。
torch
的
where
函数,我们将得到行、列 tensors
,如下所示:import torch
x = torch.tensor([[1.0, 2.0, 8.0], [-4.0, 0.0, 3.0]])
mask = x >=2.0
print(torch.where(mask))
print(x[torch.where(mask)])
哪个输出:
(tensor([0, 0, 1]), tensor([1, 2, 2]))
tensor([2., 8., 3.])
但是,将其插入
x
将仅输出
mask
ed 值,让我们得到 1D tensor
,因为我们要删除 tensor
中不计算为 True
的所有值(因此它的形状不能与原始tensor
相同,因为它的值较少,因此它会被展平为单个维度)。如果您希望 x
成为其原始形状,但
only
where
mask
ed 值是 Truth
y 那么我们可以在这些索引处用 1 填充 tensor
的零:import torch
x = torch.tensor([[1.0, 2.0, 8.0], [-4.0, 0.0, 3.0]])
mask = x >=2.0
masked_x = torch.zeros((x.shape))
masked_x[torch.where(mask)] = 1.0
print(masked_x)
输出:
tensor([[0., 1., 1.],
[0., 0., 1.]])
所以现在
masked_x
是由零组成的
tensor
,其形状与 x
相同,但对于 where
,mask
是 Truth
y。如果您希望 masked_x
由
x
的值 where
组成,则 mask
为 Truth
y,则:masked_x = torch.zeros((x.shape))
masked_x[torch.where(mask)] = x[torch.where(mask)]
print(masked_x)
输出:
tensor([[0., 2., 8.],
[0., 0., 3.]])
a
tensor
个零
where
mask
是 False
y。如果您想要其他东西,请澄清。