我有以下输入矩阵
inp_tensor = torch.tensor(
[[0.7860, 0.1115, 0.0000, 0.6524, 0.6057, 0.3725, 0.7980, 0.0000],
[1.0000, 0.1115, 0.0000, 0.6524, 0.6057, 0.3725, 0.0000, 1.0000]])
以及我想要排除的元素的索引(在本例中它们是零元素,但它们可以是任何值)
mask_indices = torch.tensor(
[[7, 2],
[2, 6]])
如何从与以下矩阵的乘法中排除这些元素:
my_tensor = torch.tensor(
[[0.8823, 0.9150, 0.3829],
[0.9593, 0.3904, 0.6009],
[0.2566, 0.7936, 0.9408],
[0.1332, 0.9346, 0.5936],
[0.8694, 0.5677, 0.7411],
[0.4294, 0.8854, 0.5739],
[0.2666, 0.6274, 0.2696],
[0.4414, 0.2969, 0.8317]])
也就是说,不要将其相乘,包括这些值(本例中为零):
a = torch.mm(inp_tensor, my_tensor)
print(a)
tensor([[1.7866, 2.5468, 1.6330],
[2.2041, 2.5388, 2.3315]])
我想排除(零)元素(以及
my_tensor
的相应行):
inp_tensor = torch.tensor(
[[0.7860, 0.1115, 0.6524, 0.6057, 0.3725, 0.7980]]) # remove the elements based on the indices (the zeros here)
my_tensor = torch.tensor(
[[0.8823, 0.9150, 0.3829],
[0.9593, 0.3904, 0.6009],
[0.1332, 0.9346, 0.5936],
[0.8694, 0.5677, 0.7411],
[0.4294, 0.8854, 0.5739],
[0.2666, 0.6274, 0.2696]]) # remove the corresponding zero elements rows
b = torch.mm(inp_tensor, my_tensor)
print(b)
>>> tensor([[1.7866, 2.5468, 1.6330]])
inp_tensor = torch.tensor([[1.0000, 0.1115, 0.6524, 0.6057, 0.3725, 1.0000]]) # remove the elements based on the indices (the zeros here)
my_tensor = torch.tensor(
[
[0.8823, 0.9150, 0.3829],
[0.9593, 0.3904, 0.6009],
[0.1332, 0.9346, 0.5936],
[0.8694, 0.5677, 0.7411],
[0.4294, 0.8854, 0.5739],
[0.4414, 0.2969, 0.8317]]) # remove the corresponding zero elements rows
c = torch.mm(inp_tensor, my_tensor)
print(c)
>>> tensor([[2.2041, 2.5388, 2.3315]])
print(torch.cat([b,c]))
>>> tensor([[1.7866, 2.5468, 1.6330],
[2.2041, 2.5388, 2.3315]])
我需要它是高效的(即,没有
for loops
),因为我的张量非常大,并且还需要保持梯度(即,如果我调用 optimizer.backward()
来更新计算图中的相关参数)
请注意,
inp_tensor
的每一行都具有相同数量的要删除的元素(例如,本示例中的零个元素)。因此,mask_indices
的每一行也将具有相同数量的元素(例如,本例中为 2)。
更新
我想到的一种方法如下:
import torch
# Given data
mask_indices = torch.tensor([[7, 2], [2, 6]])
inp_tensor = torch.tensor([[0.7860, 0.1115, 0.0000, 0.6524, 0.6057, 0.3725, 0.7980, 0.0000],
[1.0000, 0.1115, 0.0000, 0.6524, 0.6057, 0.3725, 0.0000, 1.0000]])
my_tensor = torch.tensor([[0.8823, 0.9150, 0.3829],
[0.9593, 0.3904, 0.6009],
[0.2566, 0.7936, 0.9408],
[0.1332, 0.9346, 0.5936],
[0.8694, 0.5677, 0.7411],
[0.4294, 0.8854, 0.5739],
[0.2666, 0.6274, 0.2696],
[0.4414, 0.2969, 0.8317]])
# Duplicate my_tensor
my_tensors = my_tensor[None].repeat(inp_tensor.size(0), 1, 1)
# Perform element-wise comparison with 0
remove_mask = inp_tensor == 0
# Apply mask to remove rows
filtered_tensors = my_tensors[~remove_mask]
# Reshape to original shape
filtered_tensors = filtered_tensors.view(inp_tensor.size(0), -1, my_tensor.size(-1))
print(filtered_tensors)
result = (inp_tensor.unsqueeze(-1) * my_tensors).sum(dim=1)
print(result)
>>>
tensor([[[0.8823, 0.9150, 0.3829],
[0.9593, 0.3904, 0.6009],
[0.1332, 0.9346, 0.5936],
[0.8694, 0.5677, 0.7411],
[0.4294, 0.8854, 0.5739],
[0.2666, 0.6274, 0.2696]],
[[0.8823, 0.9150, 0.3829],
[0.9593, 0.3904, 0.6009],
[0.1332, 0.9346, 0.5936],
[0.8694, 0.5677, 0.7411],
[0.4294, 0.8854, 0.5739],
[0.4414, 0.2969, 0.8317]]])
tensor([[1.7866, 2.5468, 1.6330],
[2.2041, 2.5388, 2.3315]])
但是
inp_tensor
的非零元素,这是可以删除的(例如,使用nonzero_values = inp_tensor[inp_tensor != 0].reshape(inp_tensor.shape[0],-1)
),但我不确定它是否会以有问题的方式影响梯度计算。mask_indices
包含列索引,所以你可以这样做:
cols = torch.arange(0, inp_tensor.shape[1])
col_mask_indices = torch.isin(cols, mask_indices, invert=True)
inp_tensor_filtred = inp_tensor[:, col_mask_indices]
my_tensor_filtered = my_tensor[col_mask_indices, :]
torch.mm(inp_tensor_filtred, my_tensor_filtered)
或简化形式:
mask = torch.isin(torch.arange(0, inp_tensor.shape[1]), mask_indices, invert=True)
torch.mm(inp_tensor[:, mask], my_tensor[mask, :])
讨论后编辑: 从数学上讲,您根本不需要排除 0 个元素。为了获得所需的输出,在您的情况下,您可以直接将这 2 个张量相乘,即:
torch.mm(inp_tensor, my_tensor)
对于非0的情况,可以将给定索引的元素都设为0,回到第一个问题。这与消灭它们是一样的。