Pytorch 张量在遮罩时损失维度

问题描述 投票:0回答:1

坦白说,这与 Numpy 数组在屏蔽时维度损失基本上是同一个问题,但针对的是 PyTorch 张量而不是 NumPy 数组。使用等效的 PyTorch 函数 torch.where

masked 张量 可以解决这些问题,但我发现 Google 搜索有关 PyTorch 张量的内容很快就没有找到答案。所以,我认为等效的 StackOverflow pytorch 标记问题可能对其他人有用!)

我有一个 2D PyTorch 张量(尽管它可能有更多维度),我想对其应用等效形状的二进制掩码。然而,当我应用掩模时,输出只是一维的。应用掩模后如何保持与原始张量相同的尺寸?

例如,对于

import torch x = torch.tensor([[1.0, 2.0, 8.0], [-4.0, 0.0, 3.0]]) mask = x >=2.0 print(x[mask]) tensor([2., 8., 3.])

输出现在是 1D 而不是 2D。

pytorch tensor
1个回答
0
投票
torch

where
函数,我们将得到行、列
tensors
,如下所示:
import torch

x = torch.tensor([[1.0, 2.0, 8.0], [-4.0, 0.0, 3.0]])
mask = x >=2.0
print(torch.where(mask))
print(x[torch.where(mask)])

哪个输出:

(tensor([0, 0, 1]), tensor([1, 2, 2])) tensor([2., 8., 3.])

但是,将其插入 
x

将仅输出

mask
ed 值,让我们得到 1D
tensor
,因为我们要删除
tensor
中不计算为
True
的所有值(因此它的形状不能与原始
tensor
相同,因为它的值较少,因此它会被展平为单个维度)。
如果您希望 

x

成为其原始形状,但

only
where
mask
ed 值是
Truth
y 那么我们可以在这些索引处用 1 填充
tensor
的零:
import torch

x = torch.tensor([[1.0, 2.0, 8.0], [-4.0, 0.0, 3.0]])
mask = x >=2.0
masked_x = torch.zeros((x.shape))
masked_x[torch.where(mask)] = 1.0
print(masked_x)

输出:

tensor([[0., 1., 1.], [0., 0., 1.]])

所以现在 
masked_x

是由零组成的

tensor
,其形状与
x
相同,但对于
where
mask
Truth
y。
如果您希望 

masked_x

x
的值
where
组成,则
mask
Truth
y,则:
masked_x = torch.zeros((x.shape))
masked_x[torch.where(mask)] = x[torch.where(mask)]
print(masked_x)

输出:

tensor([[0., 2., 8.], [0., 0., 3.]])

a 
tensor

个零

where
mask
False
y。
如果您想要其他东西,请澄清。

© www.soinside.com 2019 - 2024. All rights reserved.