考虑以下两种对 2d numpy 数组中的所有值求和的方法。
import numpy as np
from numba import njit
a = np.random.rand(2, 5000)
@njit(fastmath=True, cache=True)
def sum_array_slow(arr):
s = 0
for i in range(arr.shape[0]):
for j in range(arr.shape[1]):
s += arr[i, j]
return s
@njit(fastmath=True, cache=True)
def sum_array_fast(arr):
s = 0
for i in range(arr.shape[1]):
s += arr[0, i]
for i in range(arr.shape[1]):
s += arr[1, i]
return s
查看 sum_array_slow 中的嵌套循环,它似乎应该以与 sum_array_fast 相同的顺序执行完全相同的操作。然而:
In [46]: %timeit sum_array_slow(a)
7.7 µs ± 374 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
In [47]: %timeit sum_array_fast(a)
951 ns ± 2.63 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
为什么 sum_array_fast 函数比 sum_array_slow 快 8 倍,而它似乎会以相同的顺序执行相同的计算?
这是因为慢速版本不自动向量化(即编译器无法生成快速 SIMD 代码),而快速版本是。这肯定是因为 Numba 在第一个循环中没有优化索引环绕,所以它是 Numba 的miss optimization
这个可以通过分析汇编代码看出。这是慢速版本的热循环:
.LBB0_6:
addq %rbx, %rdx
vaddsd (%rax,%rdx,8), %xmm0, %xmm0
leaq 1(%rsi), %rdx
cmpq $1, %rbp
cmovleq %r13, %rdx
addq %rbx, %rdx
vaddsd (%rax,%rdx,8), %xmm0, %xmm0
leaq 2(%rsi), %rdx
cmpq $2, %rbp
cmovleq %r13, %rdx
addq %rbx, %rdx
vaddsd (%rax,%rdx,8), %xmm0, %xmm0
leaq 3(%rsi), %rdx
cmpq $3, %rbp
cmovleq %r13, %rdx
addq $4, %rsi
leaq -4(%rbp), %rdi
addq %rbx, %rdx
vaddsd (%rax,%rdx,8), %xmm0, %xmm0
cmpq $4, %rbp
movl $0, %edx
cmovgq %rsi, %rdx
movq %rdi, %rbp
cmpq %rsi, %r12
jne .LBB0_6
我们可以看到 Numba 产生了许多无用的索引检查,这使得循环非常低效。我不知道有什么干净的方法可以解决这个问题。这很可悲,因为这样的问题在实践中并不罕见。使用像 C 和 C++ 这样的本地语言可以解决这个问题(因为数组中没有索引包装)。一种不安全/丑陋的方法是在 Numba 中使用指针,但提取 Numpy 数据指针并将其提供给 Numba 似乎很痛苦(如果可能的话)。
这是快速的:
.LBB0_8:
vaddpd (%r11,%rsi,8), %ymm0, %ymm0
vaddpd 32(%r11,%rsi,8), %ymm1, %ymm1
vaddpd 64(%r11,%rsi,8), %ymm2, %ymm2
vaddpd 96(%r11,%rsi,8), %ymm3, %ymm3
vaddpd 128(%r11,%rsi,8), %ymm0, %ymm0
vaddpd 160(%r11,%rsi,8), %ymm1, %ymm1
vaddpd 192(%r11,%rsi,8), %ymm2, %ymm2
vaddpd 224(%r11,%rsi,8), %ymm3, %ymm3
vaddpd 256(%r11,%rsi,8), %ymm0, %ymm0
vaddpd 288(%r11,%rsi,8), %ymm1, %ymm1
vaddpd 320(%r11,%rsi,8), %ymm2, %ymm2
vaddpd 352(%r11,%rsi,8), %ymm3, %ymm3
vaddpd 384(%r11,%rsi,8), %ymm0, %ymm0
vaddpd 416(%r11,%rsi,8), %ymm1, %ymm1
vaddpd 448(%r11,%rsi,8), %ymm2, %ymm2
vaddpd 480(%r11,%rsi,8), %ymm3, %ymm3
addq $64, %rsi
addq $-4, %rdi
jne .LBB0_8
在这种情况下,循环得到了很好的优化。事实上,它几乎是大型阵列的最佳选择。对于小型阵列,就像在您的示例中一样,它在像我这样的某些处理器上并不是最佳的。事实上,AFAIK,展开的指令没有使用足够的寄存器来隐藏 FMA 单元的延迟(这是因为 LLVM 在内部生成次优代码)。可能需要较低级别的本机代码来修复此问题(至少,在 Numba 中没有简单的方法来修复此问题)。