为什么 Numba 的矩阵乘法很慢?

问题描述 投票:0回答:2

我试图找到一个解释,为什么我使用 Numba 的矩阵乘法比使用 NumPy 的点函数慢得多。尽管我使用的是最基本的代码来编写带有 Numba 的矩阵乘法函数,但我不认为显着降低的性能是算法造成的。为简单起见,我考虑两个 k x k 方阵,A 和 B。我的代码是这样的

1     @njit('f8[:,:](f8[:,:], f8[:,:])')
2     def numba_dot(A, B):
3
4         k=A.shape[1]
5         C = np.zeros((k, k))
6
7         for i in range(k):
8             for j in range(k):
9
10                 tmp = 0.
11                for l in range(k):
12                    tmp += A[i, l] * B[l, j]
13     
14                C[i, j] = tmp
15
16         return C

使用两个随机矩阵 1000 x 1000 矩阵重复运行此代码,通常至少需要大约 1.5 秒才能完成。 另一方面,如果我不更新矩阵 C,即如果我删除第 14 行,或者为了测试将其替换为例如以下行:

14                C[i, j] = i * j

代码在大约 1-5 毫秒内完成。相比之下,NumPy 的点函数需要大约 10 毫秒的矩阵乘法。

上述矩阵乘法代码与这个小变化之间的运行时间差异背后的原因是什么?有没有一种方法可以将变量 tmp 的值存储在 C[i, j] 中,而不会显着降低代码的性能?

python numpy numba
2个回答
2
投票

本机

NumPy
实现与矢量化操作一起使用。如果您的 CPU 支持这些,处理速度将快。当前的微处理器具有片上矩阵乘法,可通过管道传输数据传输和向量运算。

你的实现执行了 k^3 次循环迭代;十亿的任何东西都需要一些不平凡的时间。 您的代码指定您要单独执行每个单元格操作,十亿个不同的操作,而不是并行和流水线完成的大约 5k 个操作。


0
投票

对于整数,numpy 出于某种原因不使用 BLAS。 来源

import numpy as np
from numba import njit

def matrix_multiplication(A, B):
  m, n = A.shape
  _, p = B.shape
  C = np.zeros((m, p))
  for i in range(m):
    for j in range(n):
      for k in range(p):
        C[i, k] += A[i, j] * B[j, k]
  return C

@njit()
def matrix_multiplication_optimized(A, B):
  m, n = A.shape
  _, p = B.shape
  C = np.zeros((m, p))
  for i in range(m):
    for j in range(n):
      for k in range(p):
        C[i, k] += A[i, j] * B[j, k]
  return C

m = 100
n = 100
p = 100
A = np.random.randint(1, 100, size=(m,n))
B = np.random.randint(1, 100, size=(n, p))

# compile function
matrix_multiplication_optimized(A, B)

%timeit matrix_multiplication(A, B)
%timeit matrix_multiplication_optimized(A, B)
%timeit A @ B
685 ms ± 7.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.34 ms ± 5.51 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
1.49 ms ± 19.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

在这种情况下,numba 甚至比 numpy 快一点点。这让我认为 numba 正在生成使用矢量化的代码,同时也是缓存友好的(python 代码无法进一步改进)。其他循环顺序更糟,所以我可能在没有意识到的情况下使用了正确的缓存友好循环顺序。

@njit()
def matrix_multiplication_optimized2(A, B):
  m, n = A.shape
  _, p = B.shape
  C = np.zeros((m, p))
  for j in range(n):
    for k in range(p):
      for i in range(m):
        C[i, k] += A[i, j] * B[j, k]
  return C

@njit()
def matrix_multiplication_optimized3(A, B):
  m, n = A.shape
  _, p = B.shape
  C = np.zeros((m, p))
  for k in range(p):
    for i in range(m):
      for j in range(n):
        C[i, k] += A[i, j] * B[j, k]
  return C
m = 1000
n = 1000
p = 1000
A = np.random.randn(m, n)
B = np.random.randn(n, p)

# compile function
matrix_multiplication_optimized(A, B)
matrix_multiplication_optimized2(A, B)
matrix_multiplication_optimized3(A, B)


%timeit matrix_multiplication_optimized(A, B)
%timeit matrix_multiplication_optimized2(A, B)
%timeit matrix_multiplication_optimized3(A, B)
%timeit A @ B
1.45 s ± 30.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
12.6 s ± 92 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.93 s ± 35.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
30 ms ± 1.97 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

根据我的经验,对于浮点数,numpy 比 numba 快大约 50 倍。 这个问题展示了使用 BLAS 如何提高性能。 numba documentation 最后提到了 BLAS,但我不知道如何使用

numpy.linalg
.

注释掉行

C[i, j] = tmp
使临时变量无用。

@njit('f8[:,:](f8[:,:], f8[:,:])')
def numba_dot(A, B):

    k=A.shape[1]
    C = np.zeros((k, k))

    for i in range(k):
        for j in range(k):
          tmp = 0.
          for l in range(k):
              tmp += A[i, l] * B[l, j]

          # C[i, j] = tmp

    return C

@njit('f8[:,:](f8[:,:], f8[:,:])')
def numba_dot2(A, B):

    k=A.shape[1]
    C = np.zeros((k, k))

    for i in range(k):
        for j in range(k):
          # tmp = 0.
          for l in range(k):
              # tmp += A[i, l] * B[l, j]
              pass

          # C[i, j] = tmp

    return C

%timeit numba_dot(A, B)
%timeit numba_dot2(A, B)
for k, v in numba_dot.inspect_asm().items():
  print(k, v)
2.59 ms ± 158 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.6 ms ± 93.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

我无法阅读生成的代码,但临时变量可能在优化过程中被删除,因为它没有被使用。

C[i, j] = i * j
可以相对快速地执行。请记住,正在使用矢量化操作。

4.18 ms ± 218 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

你的实现比我的慢,所以我尝试颠倒 l 和 j。

@njit('f8[:,:](f8[:,:], f8[:,:])')
def numba_dot(A, B):

    k=A.shape[1]
    C = np.zeros((k, k))

    for i in range(k):
        for j in range(k):
          tmp = 0.
          for l in range(k):
              tmp += A[i, l] * B[l, j]

          C[i, j] = tmp

    return C

@njit('f8[:,:](f8[:,:], f8[:,:])')
def numba_dot2(A, B):

    k=A.shape[1]
    C = np.zeros((k, k))

    for i in range(k):
        for l in range(k):
          tmp = 0.
          for j in range(k):
              tmp += A[i, l] * B[l, j]
              C[i, j] = tmp

    return C



%timeit numba_dot(A, B)
%timeit numba_dot2(A, B)
3.16 s ± 36.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.57 s ± 24.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

这样做时,保留一个临时变量并没有什么意义,因为 j 是最后一个循环。我没有看到直接更新 C[i, j] 有任何问题。

© www.soinside.com 2019 - 2024. All rights reserved.