我想计算矩阵
A * B * A'
和 A
的项 B
。A'
是 A
的转置。
有没有一种有效的方法可以在Python上计算这个?
我可以做
A @ B @ A.T
但我想要一些东西:
我有最直接的基于 numba 的代码:
import numpy as np
from numba import jit, njit
A = np.random.rand(3, 3)
B = np.random.rand(3, 3)
B = B + B.T
@njit
def my_fun(A, B, C):
for i in range(A.shape[0]):
for j in range(A.shape[1]):
for k in range(B.shape[1]):
for l in range(i, A.shape[1]):
C[i, l] += A[i, j] * B[j, k] * A[l, k]
for i in range(1, A.shape[0]):
for j in range(i):
C[i, j] = C[j, i]
return
C = np.zeros(shape = (A.shape[0], A.shape[0]))
my_fun(A, B, C)
np.all(C == C.T)
输出是对称的。 我们可以在性能方面做得更好吗?
提供的代码在
O(n**4)
中运行,而两个矩阵乘法在 O(n**3)
中运行。因此,2 个矩阵乘法肯定要快得多。人们可以尝试改变循环的顺序,然后分解一些计算,但结果可能类似于 2 乘法矩阵。一种更简单的方法包括编写执行 2 个乘法矩阵的代码,然后通过交换循环进行优化,使其更加适合 SIMD,然后对其进行调整以仅计算最后一个乘法矩阵的上三角部分。中间矩阵实际上可以逐行计算(这对于大型矩阵来说更有效)。下三角部分可以像在提供的实现中那样计算。
这是生成的代码:
@njit
def faster_fun(A, B, C):
# Constraints coming from the 2 matrix multiplications
assert A.shape[1] == B.shape[0] and A.shape[1] == B.shape[1]
n, m = A.shape
line = np.zeros(m)
for i in range(n):
line.fill(0.0)
for k in range(m):
factor = A[i, k]
for j in range(m):
line[j] += factor * B[k, j]
for j in range(i, n):
for k in range(m):
C[i, j] += line[k] * A[j, k]
for i in range(1, A.shape[0]):
for j in range(i):
C[i, j] = C[j, i]
该解决方案比提供的解决方案要快得多。不过,比
A @ B @ A.T
慢一点。这是因为 Numpy 使用 BLAS 库来计算矩阵乘法,而我的机器上使用的 BLAS 是 OpenBLAS:一个高度优化的实现。 OpenBLAS 以“并行”方式执行矩阵乘法,而上述 Numba 代码是顺序的。如果您计划从多线程代码运行 Numba 函数,那么 Numba 代码将比 Numpy 代码更快。否则,您可以这样并行化循环i
:from numba import prange
@njit(parallel=True)
def faster_fun(A, B, C):
assert A.shape[1] == B.shape[0] and A.shape[1] == B.shape[1]
n, m = A.shape
for i in prange(n):
line = np.zeros(m)
# [....] (same code later, but not need for line.fill)
# [...]
此解决方案比大多数机器上的顺序代码更快,除了非常小的矩阵(因为生成线程、分配工作并等待它们不是免费的)。然而,它仍然比我的机器上的基本 Numpy 代码
A @ B @ A.T
慢。这是因为
line
的分配根本无法扩展。事实上,在我的 6 核 CPU 上,并行代码仅比顺序代码快 3.2 倍。 AFAIK,Numba 中还没有简单的解决方案来解决这个(已知)问题。在 Cython 中,解决方案是使用堆栈分配或线程本地分配来解决此问题。尽管如此,即使有完美的缩放,上述代码在 60x60 矩阵上也只会快 30%。这表明要超越高度优化的 BLAS 实现是多么困难,而且对称性只能为中等大小的矩阵提供较小的性能增益。对于非常小的矩阵,顺序 Numba 实现应该比 Numpy 代码快得多(由于线程开销)。 注意:第二个矩阵乘法只需要计算一半(最佳),但第一个矩阵乘法肯定需要完全计算。因此,理论上的最佳增益肯定是 25%,这是相当小的(尽管在某些情况下 Numba 可能比这快一点)。