在 numpy 中计算 A * B * A' (`A @ B @ A.T`) 并保留对称性

问题描述 投票:0回答:1

我想计算矩阵

A * B * A'
A
的项
B

A'
A
的转置。

有没有一种有效的方法可以在Python上计算这个?

我可以做

A @ B @ A.T
但我想要一些东西:

  1. 利用对称性进行计算。
  2. 保证对称的结果。

我有最直接的基于 numba 的代码:

import numpy as np
from numba import jit, njit

A = np.random.rand(3, 3)
B = np.random.rand(3, 3)
B = B + B.T

@njit
def my_fun(A, B, C):
    for i in range(A.shape[0]):
        for j in range(A.shape[1]):
            for k in range(B.shape[1]):
                for l in range(i, A.shape[1]):
                    C[i, l] += A[i, j] * B[j, k] * A[l, k]
                    
    for i in range(1, A.shape[0]):
        for j in range(i):
            C[i, j] = C[j, i]
    
    return

C = np.zeros(shape = (A.shape[0], A.shape[0]))

my_fun(A, B, C)

np.all(C == C.T)

输出是对称的。 我们可以在性能方面做得更好吗?

python numpy performance linear-algebra
1个回答
0
投票

提供的代码在

O(n**4)
中运行,而两个矩阵乘法在
O(n**3)
中运行。因此,2 个矩阵乘法肯定要快得多。人们可以尝试改变循环的顺序,然后分解一些计算,但结果可能类似于 2 乘法矩阵。一种更简单的方法包括编写执行 2 个乘法矩阵的代码,然后通过交换循环进行优化,使其更加适合 SIMD,然后对其进行调整以仅计算最后一个乘法矩阵的上三角部分。中间矩阵实际上可以逐行计算(这对于大型矩阵来说更有效)。下三角部分可以像在提供的实现中那样计算。

这是生成的代码:

@njit
def faster_fun(A, B, C):
    # Constraints coming from the 2 matrix multiplications
    assert A.shape[1] == B.shape[0] and A.shape[1] == B.shape[1]

    n, m = A.shape
    line = np.zeros(m)

    for i in range(n):
        line.fill(0.0)

        for k in range(m):
            factor = A[i, k]
            for j in range(m):
                line[j] += factor * B[k, j]

        for j in range(i, n):
            for k in range(m):
                C[i, j] += line[k] * A[j, k]

    for i in range(1, A.shape[0]):
        for j in range(i):
            C[i, j] = C[j, i]

该解决方案比提供的解决方案要快得多。不过,比

A @ B @ A.T
慢一点。这是因为 Numpy 使用 BLAS 库来计算矩阵乘法,而我的机器上使用的 BLAS 是 OpenBLAS:一个高度优化的实现。 OpenBLAS 以“并行”方式执行矩阵乘法,而上述 Numba 代码是顺序的。如果您计划从多线程代码运行 Numba 函数,那么 Numba 代码将比 Numpy 代码更快。否则,您可以这样并行化循环i
from numba import prange

@njit(parallel=True)
def faster_fun(A, B, C):
    assert A.shape[1] == B.shape[0] and A.shape[1] == B.shape[1]
    n, m = A.shape

    for i in prange(n):
        line = np.zeros(m)

        # [....] (same code later, but not need for line.fill)

    # [...]

此解决方案比大多数机器上的顺序代码更快,除了非常小的矩阵(因为生成线程、分配工作并等待它们不是免费的)。然而,它仍然比我的机器上的基本 Numpy 代码
A @ B @ A.T

慢。这是因为

line
的分配根本无法扩展。事实上,在我的 6 核 CPU 上,并行代码仅比顺序代码快 3.2 倍。 AFAIK,Numba 中还没有简单的解决方案来解决这个(已知)问题。在 Cython 中,解决方案是使用堆栈分配或线程本地分配来解决此问题。尽管如此,即使有完美的缩放,上述代码在 60x60 矩阵上也只会快 30%。这表明要超越高度优化的 BLAS 实现是多么困难,而且对称性只能为中等大小的矩阵提供
较小的性能增益
。对于非常小的矩阵,顺序 Numba 实现应该比 Numpy 代码快得多(由于线程开销)。 注意:第二个矩阵乘法只需要计算一半(最佳),但第一个矩阵乘法肯定需要完全计算。因此,理论上的最佳增益肯定是 25%,这是相当小的(尽管在某些情况下 Numba 可能比这快一点)。

© www.soinside.com 2019 - 2024. All rights reserved.