如何用二维遮罩遮蔽一个三维张量,并保持原始矢量的尺寸?

问题描述 投票:0回答:1

假设,我有一个三维张量 A

A = torch.arange(24).view(4, 3, 2)
print(A)

并要求用二维张量对其进行遮挡

mask = torch.zeros((4, 3), dtype=torch.int64)  # or dtype=torch.ByteTensor
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1
print('Mask: ', mask)

使用 PyTorch 的 masked_select 功能会导致以下错误。

torch.masked_select(X, (mask == 1))


---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-72-fd6809d2c4cc> in <module>
     12 
     13 # Select based on new mask
---> 14 Y = torch.masked_select(X, (mask == 1))
     15 #Y = X * mask_
     16 print(Y)

RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 2

如何用 2D 遮罩遮罩 3D 张量,并保持原始向量的尺寸?希望得到任何提示。

pytorch tensor
1个回答
0
投票

基本上,我们需要将张量掩码的维度与被掩码的张量相匹配。

有两种方法可以做到这一点。

方法1:不保留原始张量尺寸。

X = torch.arange(24).view(4, 3, 2)
print(X)

mask = torch.zeros((4, 3), dtype=torch.int64)  # or dtype=torch.ByteTensor
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1
print('Mask: ', mask)

# Add a dimension to the mask tensor and expand it to the size of original tensor
mask_ = mask.unsqueeze(-1).expand(X.size())
print(mask_)

# Select based on the new expanded mask
Y = torch.masked_select(X, (mask_ == 1)) # does not preserve the dims
print(Y)

方法1的输出。

tensor([ 0,  1,  8,  9, 18, 19])

方法2:保留原始张量尺寸(通过填充)。

X = torch.arange(24).view(4, 3, 2)
print(X)

mask = torch.zeros((4, 3), dtype=torch.int64)  # or dtype=torch.ByteTensor
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1
print('Mask: ', mask)

# Add a dimension to the mask tensor and expand it to the size of original tensor
mask_ = mask.unsqueeze(-1).expand(X.size())
print(mask_)

# Select based on the new expanded mask
Y = X * mask_
print(Y)

方法2的输出。

tensor([[[ 0,  1],
         [ 2,  3],
         [ 4,  5]],

        [[ 6,  7],
         [ 8,  9],
         [10, 11]],

        [[12, 13],
         [14, 15],
         [16, 17]],

        [[18, 19],
         [20, 21],
         [22, 23]]])
Mask:  tensor([[1, 0, 0],
        [0, 1, 0],
        [0, 0, 0],
        [1, 0, 0]])
tensor([[[1, 1],
         [0, 0],
         [0, 0]],

        [[0, 0],
         [1, 1],
         [0, 0]],

        [[0, 0],
         [0, 0],
         [0, 0]],

        [[1, 1],
         [0, 0],
         [0, 0]]])
tensor([[[ 0,  1],
         [ 0,  0],
         [ 0,  0]],

        [[ 0,  0],
         [ 8,  9],
         [ 0,  0]],

        [[ 0,  0],
         [ 0,  0],
         [ 0,  0]],

        [[18, 19],
         [ 0,  0],
         [ 0,  0]]]
© www.soinside.com 2019 - 2024. All rights reserved.