仅对给定坐标计算矩阵乘法。

问题描述 投票:0回答:1

在PyTorch中,我想计算一下

E * A.mm(B)

其中E可以是一个由0和1组成的非常稀疏的矩阵。换句话说,我想计算A.mm(B),然后只留下某个坐标。有没有一种方法可以高效地计算这样一个稀疏矩阵?我完全可以控制矩阵的表示方式。

另外,在大多数情况下,E 只由 1 组成,所以我希望这种情况也能高效处理。

python-3.x pytorch sparse-matrix
1个回答
3
投票

你不需要元素乘法,因为 E 本质上是一个布尔矩阵,用作选择值的掩码,其中的 E 为1,丢弃其中 E 是0。

C = A.mm(B)

# Ensure that E is a boolean matrix to only keep values where E is True,
# otherwise the 0s and 1s would be treated as indices to select the values.
C = C[E.to(torch.bool)]

如果你想避免整个矩阵乘法,而只计算之后要屏蔽的值,你需要手动为 AB 中产生所需数值。C.

矩阵乘法 C = AB,其中 A 是一个 m x n 矩阵和 B 一个 n x p 矩阵,产生一个 m x p 矩阵 C,其数值是通过乘以......得到的。i-th 一排 A 随着 拇指 一列 B 元素,并取其总和 n 产品。形式上。

Matrix Multiplication Formula

给定 E,一个 m x p 矩阵,决定哪些元素的 C 是需要的,所需元素的索引对给出如下。

Indices

# Indices of required elements (i.e. indices of non-zero elements of E)
# Separate the tensor of (i, j) pairs, into a pair of tensors,
# containing the indices i and j respectively.
indices_i, indices_j = E.nonzero().unbind(dim=1)

# Select all needed rows of A and the needed columns of B
A = A[indices_i]
B = B[:, indices_j]

# Calculate the values
# B is transposed to change the column vectors to row vectors
# such that the two can be multiplied element-wise.
C = torch.sum(A * B.transpose(0, 1), dim=1)

选择性地计算出你想要的值 与进行整个矩阵乘法然后只保留你想要的值相比是否更有效率?

答案是肯定的 没有. 矩阵乘法是高度优化的,比手动做步骤与操作本身就优化得多。特别是,当 E 含有大部分的1,那么你基本上是重新实现了矩阵乘法,这保证了效率的降低。即使是在 E 大部分都是0,矩阵乘法只是更快。

为了支持我的说法,我对它们进行了计时。为了方便起见,我在IPython中做了,因为IPython有内置的 %timeit 命令。

In [1]: import torch
   ...:
   ...:
   ...: def masked(A, B, E):
   ...:     C = A.mm(B)
   ...:     return C[E]
   ...:
   ...:
   ...: def selective(A, B, E):
   ...:     indices_i, indices_j = E.nonzero().unbind(dim=1)
   ...:     return torch.sum(A[indices_i] * B[:, indices_j].transpose(0, 1), dim=1)
   ...:
   ...:
   ...: A = torch.rand(1200, 1000)
   ...: B = torch.rand(1000, 1100)
   ...: # Only 10% of the elements are 1
   ...: E_mostly_zeros = torch.rand(1200, 1100) < 0.1
   ...: # 90% of the elements are 1
   ...: E_mostly_ones = torch.rand(1200, 1100) < 0.9

In [2]: # All close instead of equal to account for floating point errors
   ...: torch.allclose(masked(A, B, E_mostly_ones), selective(A, B, E_mostly_ones))
Out[2]: True

In [3]: # All close instead of equal to account for floating point errors
   ...: torch.allclose(masked(A, B, E_mostly_zeros), selective(A, B, E_mostly_zeros))
Out[3]: True

In [4]: %timeit masked(A, B, E_mostly_ones)
8.16 ms ± 20.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [5]: %timeit selective(A, B, E_mostly_ones)
2.09 s ± 11.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [6]: %timeit masked(A, B, E_mostly_zeros)
5.73 ms ± 24.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

] In [7]: %timeit selective(A, B, E_mostly_zeros)
266 ms ± 3.36 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

矩阵乘法的速度快得惊人,是超过了 256x 更快 E 含有90%的(8.16ms vs 2090ms),超过了 46x 更快 E 只包含10%的(5.73ms vs 266ms)。

© www.soinside.com 2019 - 2024. All rights reserved.