How can I do a matrix multiplication that excludes elements based on a mask?

Question (votes: 0, answers: 1)

I have the following input matrix:

inp_tensor = torch.tensor(
        [[0.7860, 0.1115, 0.0000, 0.6524, 0.6057, 0.3725, 0.7980, 0.0000],
        [1.0000, 0.1115, 0.0000, 0.6524, 0.6057, 0.3725, 0.0000, 1.0000]])

and the indices of the elements I want to exclude (in this example they are the zero elements, but they could hold any value):

mask_indices = torch.tensor(
[[7, 2],
[2, 6]])

How can I exclude these elements from the multiplication with the following matrix:

my_tensor = torch.tensor(
        [[0.8823, 0.9150, 0.3829],
        [0.9593, 0.3904, 0.6009],
        [0.2566, 0.7936, 0.9408],
        [0.1332, 0.9346, 0.5936],
        [0.8694, 0.5677, 0.7411],
        [0.4294, 0.8854, 0.5739],
        [0.2666, 0.6274, 0.2696],
        [0.4414, 0.2969, 0.8317]])

That is, instead of a multiplication that includes these values (zeros in this example):

a = torch.mm(inp_tensor, my_tensor)
print(a)
tensor([[1.7866, 2.5468, 1.6330],
        [2.2041, 2.5388, 2.3315]])

I want to exclude the (zero) elements, together with the corresponding rows of my_tensor:

inp_tensor = torch.tensor(
        [[0.7860, 0.1115, 0.6524, 0.6057, 0.3725, 0.7980]]) # remove the elements based on the indices (the zeros here)

my_tensor = torch.tensor(
        [[0.8823, 0.9150, 0.3829],
        [0.9593, 0.3904, 0.6009],
        [0.1332, 0.9346, 0.5936],
        [0.8694, 0.5677, 0.7411],
        [0.4294, 0.8854, 0.5739],
        [0.2666, 0.6274, 0.2696]]) # remove the corresponding zero elements rows

b = torch.mm(inp_tensor, my_tensor)
print(b)
>>> tensor([[1.7866, 2.5468, 1.6330]])

inp_tensor = torch.tensor([[1.0000, 0.1115, 0.6524, 0.6057, 0.3725, 1.0000]]) # remove the elements based on the indices (the zeros here)

my_tensor = torch.tensor(
        [
        [0.8823, 0.9150, 0.3829],                
        [0.9593, 0.3904, 0.6009],
        [0.1332, 0.9346, 0.5936],
        [0.8694, 0.5677, 0.7411],
        [0.4294, 0.8854, 0.5739],
        [0.4414, 0.2969, 0.8317]])  # remove the corresponding zero elements rows

c = torch.mm(inp_tensor, my_tensor)
print(c)
>>> tensor([[2.2041, 2.5388, 2.3315]])
print(torch.cat([b,c]))
>>> tensor([[1.7866, 2.5468, 1.6330],
        [2.2041, 2.5388, 2.3315]])

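I need this to be efficient (i.e., no for loops), because my tensors are very large. It also has to preserve gradients (i.e., calling loss.backward() followed by optimizer.step() should still update the relevant parameters in the computation graph).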

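Note that every row of inp_tensor has the same number of elements to remove (the zeros in this example). Consequently, every row of mask_indices also has the same number of entries (2 in this example).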
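Because every row drops the same number of columns, the indices to exclude can be turned into a rectangular tensor of indices to keep. The following is a minimal sketch of that idea and is not part of the original question. The names keep, keep_indices and inp_kept are hypothetical, and the sketch assumes that mask_indices contains no duplicate index within a row.

```python
import torch

inp_tensor = torch.tensor(
    [[0.7860, 0.1115, 0.0000, 0.6524, 0.6057, 0.3725, 0.7980, 0.0000],
     [1.0000, 0.1115, 0.0000, 0.6524, 0.6057, 0.3725, 0.0000, 1.0000]])
mask_indices = torch.tensor([[7, 2], [2, 6]])

n_rows, n_cols = inp_tensor.shape
n_keep = n_cols - mask_indices.size(1)  # identical for every row by assumption

# Boolean mask: True where a column is kept, False at the excluded positions.
keep = torch.ones(n_rows, n_cols).scatter(1, mask_indices, 0.0).bool()

# nonzero() lists (row, col) pairs in row-major order, so the column part
# can be reshaped into one row of kept column indices per input row.
keep_indices = keep.nonzero()[:, 1].reshape(n_rows, n_keep)

# Pick the kept values from each row; shape (n_rows, n_keep).
inp_kept = inp_tensor.gather(1, keep_indices)
print(keep_indices)
print(inp_kept)
```

The rows of inp_kept correspond to the two shortened inp_tensor rows written out by hand in the question.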

Update

One approach I came up with is the following:

import torch

# Given data
mask_indices = torch.tensor([[7, 2], [2, 6]])
inp_tensor = torch.tensor([[0.7860, 0.1115, 0.0000, 0.6524, 0.6057, 0.3725, 0.7980, 0.0000],
                           [1.0000, 0.1115, 0.0000, 0.6524, 0.6057, 0.3725, 0.0000, 1.0000]])
                           
my_tensor = torch.tensor([[0.8823, 0.9150, 0.3829],
                          [0.9593, 0.3904, 0.6009],
                          [0.2566, 0.7936, 0.9408],
                          [0.1332, 0.9346, 0.5936],
                          [0.8694, 0.5677, 0.7411],
                          [0.4294, 0.8854, 0.5739],
                          [0.2666, 0.6274, 0.2696],
                          [0.4414, 0.2969, 0.8317]])

# Duplicate my_tensor 
my_tensors = my_tensor[None].repeat(inp_tensor.size(0), 1, 1)

# Perform element-wise comparison with 0
remove_mask = inp_tensor == 0

# Apply mask to remove rows
filtered_tensors = my_tensors[~remove_mask]

# Reshape to original shape
filtered_tensors = filtered_tensors.view(inp_tensor.size(0), -1, my_tensor.size(-1))
print(filtered_tensors)

result = (inp_tensor.unsqueeze(-1) * my_tensors).sum(dim=1)
print(result)

>>>
tensor([[[0.8823, 0.9150, 0.3829],
         [0.9593, 0.3904, 0.6009],
         [0.1332, 0.9346, 0.5936],
         [0.8694, 0.5677, 0.7411],
         [0.4294, 0.8854, 0.5739],
         [0.2666, 0.6274, 0.2696]],

        [[0.8823, 0.9150, 0.3829],
         [0.9593, 0.3904, 0.6009],
         [0.1332, 0.9346, 0.5936],
         [0.8694, 0.5677, 0.7411],
         [0.4294, 0.8854, 0.5739],
         [0.4414, 0.2969, 0.8317]]])
tensor([[1.7866, 2.5468, 1.6330],
        [2.2041, 2.5388, 2.3315]])

However:

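  1. I am not sure this is efficient or scalable, because I replicate the (very large) weight tensor several times.
  2. I am not sure whether repeating the tensor and removing rows affects the gradient computation in a problematic way.
  3. This still multiplies more than just the non-zero elements of inp_tensor. Those could be removed as well (e.g., using nonzero_values = inp_tensor[inp_tensor != 0].reshape(inp_tensor.shape[0], -1)), but I am not sure whether that would affect the gradient computation in a problematic way.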
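One way to probe the gradient concerns is to run a small experiment with requires_grad=True and inspect the resulting gradient. The sketch below is an illustration under assumptions, not the asker's code: it uses random data, a hypothetical weight tensor standing in for my_tensor, and integer indexing in place of repeat() followed by boolean masking.

```python
import torch

torch.manual_seed(0)
inp = torch.rand(2, 8)
weight = torch.rand(8, 3, requires_grad=True)  # stands in for my_tensor
mask_indices = torch.tensor([[7, 2], [2, 6]])

n_rows, n_cols = inp.shape
n_keep = n_cols - mask_indices.size(1)

# True where a column is kept, False at the excluded positions.
keep = torch.ones(n_rows, n_cols).scatter(1, mask_indices, 0.0).bool()
keep_indices = keep.nonzero()[:, 1].reshape(n_rows, n_keep)

inp_kept = inp.gather(1, keep_indices)  # (n_rows, n_keep)
weight_kept = weight[keep_indices]      # (n_rows, n_keep, 3), no repeat() call

# Per-row product: out[b] = sum_k inp_kept[b, k] * weight_kept[b, k]
out = torch.einsum("bk,bkd->bd", inp_kept, weight_kept)

out.sum().backward()
# Rows of weight that are excluded for every input row get a zero gradient;
# rows used by several input rows accumulate their contributions.
print(weight.grad)
```

Indexing and gathering are differentiable with respect to the indexed values, so autograd routes gradients back to the original weight rows. Note that weight_kept still materialises a tensor of n_rows × n_keep × 3 elements, so this sketch does not remove the memory concern from point 1.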
machine-learning matrix pytorch gradient-descent
1 answer
-1 votes

mask_indices contains column indices, so you can do this:

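# Column positions 0 .. K-1 of inp_tensor
cols = torch.arange(0, inp_tensor.shape[1])
# True for every column that does not appear anywhere in mask_indices
col_mask_indices = torch.isin(cols, mask_indices, invert=True)
# Drop those columns from the input and the matching rows from the weights
inp_tensor_filtered = inp_tensor[:, col_mask_indices]
my_tensor_filtered = my_tensor[col_mask_indices, :]
torch.mm(inp_tensor_filtered, my_tensor_filtered)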

Or, in condensed form:

mask = torch.isin(torch.arange(0, inp_tensor.shape[1]), mask_indices, invert=True)
torch.mm(inp_tensor[:, mask], my_tensor[mask, :])
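It may help to check which columns this mask actually selects. torch.isin tests each element of its first argument against all values in the second argument, regardless of that argument's shape, so a 2-D mask_indices is treated as one flat set. The short check below is an illustrative sketch; the hard-coded column count of 8 simply mirrors inp_tensor from the question.

```python
import torch

mask_indices = torch.tensor([[7, 2], [2, 6]])
cols = torch.arange(0, 8)  # 8 columns, as in inp_tensor

# invert=True marks the columns that are NOT found anywhere in mask_indices.
mask = torch.isin(cols, mask_indices, invert=True)

# Columns that the mask removes from every row.
dropped = cols[~mask]
print(dropped)
```

In the question each row of mask_indices is meant to apply only to the matching row of inp_tensor. With a single shared column mask, row 0 also loses column 6 and row 1 also loses column 7, so the result only matches the desired output when all rows exclude the same columns.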

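Edit after discussion: mathematically, you do not need to exclude the zero elements at all, because a zero entry contributes nothing to the dot product. To get the desired output in your case, you can simply multiply the two tensors directly, i.e.: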

torch.mm(inp_tensor, my_tensor)

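For the non-zero case, you can set the elements at the given indices to 0, which reduces the problem to the first case. This has the same effect as removing them.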
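A minimal sketch of that idea, assuming the per-row indices in mask_indices are zeroed with scatter before a single torch.mm call. The value 9.0 is a made-up stand-in for "any non-zero value" at the excluded positions, and the name keep is hypothetical.

```python
import torch

# Same layout as the question, but the excluded positions hold 9.0 instead of 0.
inp_tensor = torch.tensor(
    [[0.7860, 0.1115, 9.0, 0.6524, 0.6057, 0.3725, 0.7980, 9.0],
     [1.0000, 0.1115, 9.0, 0.6524, 0.6057, 0.3725, 9.0, 1.0000]])
mask_indices = torch.tensor([[7, 2], [2, 6]])
my_tensor = torch.tensor(
    [[0.8823, 0.9150, 0.3829],
     [0.9593, 0.3904, 0.6009],
     [0.2566, 0.7936, 0.9408],
     [0.1332, 0.9346, 0.5936],
     [0.8694, 0.5677, 0.7411],
     [0.4294, 0.8854, 0.5739],
     [0.2666, 0.6274, 0.2696],
     [0.4414, 0.2969, 0.8317]], requires_grad=True)

# 1.0 everywhere, 0.0 at the per-row positions listed in mask_indices.
keep = torch.ones_like(inp_tensor).scatter(1, mask_indices, 0.0)

# Zeroed entries contribute nothing to the dot products, so one mm suffices.
result = torch.mm(inp_tensor * keep, my_tensor)
print(result)

# The multiplication by keep is differentiable, so gradients still reach my_tensor.
result.sum().backward()
print(my_tensor.grad)
```

This keeps a single copy of my_tensor and adds only one mask the size of inp_tensor, at the cost of still multiplying by the zeroed entries.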
