在PyTorch中,我想计算一下
E * A.mm(B)
其中E可以是一个由0和1组成的非常稀疏的矩阵。换句话说,我想计算A.mm(B),然后只留下某个坐标。有没有一种方法可以高效地计算这样一个稀疏矩阵?我完全可以控制矩阵的表示方式。
另外,在大多数情况下,E 只由 1 组成,所以我希望这种情况也能高效处理。
你不需要元素乘法,因为 E
本质上是一个布尔矩阵,用作选择值的掩码,其中的 E
为1,丢弃其中 E
是0。
C = A.mm(B)
# Ensure that E is a boolean matrix to only keep values where E is True,
# otherwise the 0s and 1s would be treated as indices to select the values.
C = C[E.to(torch.bool)]
如果你想避免整个矩阵乘法,而只计算之后要屏蔽的值,你需要手动为 A
和 B
中产生所需数值。C
.
矩阵乘法 C = AB,其中 A 是一个 m x n 矩阵和 B 一个 n x p 矩阵,产生一个 m x p 矩阵 C,其数值是通过乘以......得到的。i-th 一排 A 随着 拇指 一列 B 元素,并取其总和 n 产品。形式上。
给定 E,一个 m x p 矩阵,决定哪些元素的 C 是需要的,所需元素的索引对给出如下。
# Indices of required elements (i.e. indices of non-zero elements of E)
# Separate the tensor of (i, j) pairs, into a pair of tensors,
# containing the indices i and j respectively.
indices_i, indices_j = E.nonzero().unbind(dim=1)
# Select all needed rows of A and the needed columns of B
A = A[indices_i]
B = B[:, indices_j]
# Calculate the values
# B is transposed to change the column vectors to row vectors
# such that the two can be multiplied element-wise.
C = torch.sum(A * B.transpose(0, 1), dim=1)
选择性地计算出你想要的值 与进行整个矩阵乘法然后只保留你想要的值相比是否更有效率?
答案是肯定的 没有. 矩阵乘法是高度优化的,比手动做步骤与操作本身就优化得多。特别是,当 E 含有大部分的1,那么你基本上是重新实现了矩阵乘法,这保证了效率的降低。即使是在 E 大部分都是0,矩阵乘法只是更快。
为了支持我的说法,我对它们进行了计时。为了方便起见,我在IPython中做了,因为IPython有内置的 %timeit
命令。
In [1]: import torch
...:
...:
...: def masked(A, B, E):
...: C = A.mm(B)
...: return C[E]
...:
...:
...: def selective(A, B, E):
...: indices_i, indices_j = E.nonzero().unbind(dim=1)
...: return torch.sum(A[indices_i] * B[:, indices_j].transpose(0, 1), dim=1)
...:
...:
...: A = torch.rand(1200, 1000)
...: B = torch.rand(1000, 1100)
...: # Only 10% of the elements are 1
...: E_mostly_zeros = torch.rand(1200, 1100) < 0.1
...: # 90% of the elements are 1
...: E_mostly_ones = torch.rand(1200, 1100) < 0.9
In [2]: # All close instead of equal to account for floating point errors
...: torch.allclose(masked(A, B, E_mostly_ones), selective(A, B, E_mostly_ones))
Out[2]: True
In [3]: # All close instead of equal to account for floating point errors
...: torch.allclose(masked(A, B, E_mostly_zeros), selective(A, B, E_mostly_zeros))
Out[3]: True
In [4]: %timeit masked(A, B, E_mostly_ones)
8.16 ms ± 20.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [5]: %timeit selective(A, B, E_mostly_ones)
2.09 s ± 11.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [6]: %timeit masked(A, B, E_mostly_zeros)
5.73 ms ± 24.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
] In [7]: %timeit selective(A, B, E_mostly_zeros)
266 ms ± 3.36 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
矩阵乘法的速度快得惊人,是超过了 256x 更快 E 含有90%的(8.16ms vs 2090ms),超过了 46x 更快 E 只包含10%的(5.73ms vs 266ms)。