为什么这两种不同的二维数组求和方式会有如此不同的性能?

问题描述 投票:0回答:1

考虑以下两种对 2d numpy 数组中的所有值求和的方法。

import numpy as np
from numba import njit
a = np.random.rand(2, 5000)

@njit(fastmath=True, cache=True)
def sum_array_slow(arr):
    s = 0
    for i in range(arr.shape[0]):
        for j in range(arr.shape[1]):
            s += arr[i, j]
    return s
    
@njit(fastmath=True, cache=True)
def sum_array_fast(arr):
    s = 0
    for i in range(arr.shape[1]):
        s += arr[0, i]
    for i in range(arr.shape[1]):
        s += arr[1, i]
    return s

查看 sum_array_slow 中的嵌套循环,它似乎应该以与 sum_array_fast 相同的顺序执行完全相同的操作。然而:

In [46]: %timeit sum_array_slow(a)
7.7 µs ± 374 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

In [47]: %timeit sum_array_fast(a)
951 ns ± 2.63 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

为什么 sum_array_fast 函数比 sum_array_slow 快 8 倍,而它似乎会以相同的顺序执行相同的计算?

python numpy numba
1个回答
1
投票

这是因为慢速版本不自动向量化(即编译器无法生成快速 SIMD 代码),而快速版本是。这肯定是因为 Numba 在第一个循环中没有优化索引环绕,所以它是 Numba 的miss optimization

这个可以通过分析汇编代码看出。这是慢速版本的热循环:

.LBB0_6:
    addq    %rbx, %rdx
    vaddsd  (%rax,%rdx,8), %xmm0, %xmm0
    leaq    1(%rsi), %rdx
    cmpq    $1, %rbp
    cmovleq %r13, %rdx
    addq    %rbx, %rdx
    vaddsd  (%rax,%rdx,8), %xmm0, %xmm0
    leaq    2(%rsi), %rdx
    cmpq    $2, %rbp
    cmovleq %r13, %rdx
    addq    %rbx, %rdx
    vaddsd  (%rax,%rdx,8), %xmm0, %xmm0
    leaq    3(%rsi), %rdx
    cmpq    $3, %rbp
    cmovleq %r13, %rdx
    addq    $4, %rsi
    leaq    -4(%rbp), %rdi
    addq    %rbx, %rdx
    vaddsd  (%rax,%rdx,8), %xmm0, %xmm0
    cmpq    $4, %rbp
    movl    $0, %edx
    cmovgq  %rsi, %rdx
    movq    %rdi, %rbp
    cmpq    %rsi, %r12
    jne .LBB0_6

我们可以看到 Numba 产生了许多无用的索引检查,这使得循环非常低效。我不知道有什么干净的方法可以解决这个问题。这很可悲,因为这样的问题在实践中并不罕见。使用像 C 和 C++ 这样的本地语言可以解决这个问题(因为数组中没有索引包装)。一种不安全/丑陋的方法是在 Numba 中使用指针,但提取 Numpy 数据指针并将其提供给 Numba 似乎很痛苦(如果可能的话)。

这是快速的:

.LBB0_8:
    vaddpd  (%r11,%rsi,8), %ymm0, %ymm0
    vaddpd  32(%r11,%rsi,8), %ymm1, %ymm1
    vaddpd  64(%r11,%rsi,8), %ymm2, %ymm2
    vaddpd  96(%r11,%rsi,8), %ymm3, %ymm3
    vaddpd  128(%r11,%rsi,8), %ymm0, %ymm0
    vaddpd  160(%r11,%rsi,8), %ymm1, %ymm1
    vaddpd  192(%r11,%rsi,8), %ymm2, %ymm2
    vaddpd  224(%r11,%rsi,8), %ymm3, %ymm3
    vaddpd  256(%r11,%rsi,8), %ymm0, %ymm0
    vaddpd  288(%r11,%rsi,8), %ymm1, %ymm1
    vaddpd  320(%r11,%rsi,8), %ymm2, %ymm2
    vaddpd  352(%r11,%rsi,8), %ymm3, %ymm3
    vaddpd  384(%r11,%rsi,8), %ymm0, %ymm0
    vaddpd  416(%r11,%rsi,8), %ymm1, %ymm1
    vaddpd  448(%r11,%rsi,8), %ymm2, %ymm2
    vaddpd  480(%r11,%rsi,8), %ymm3, %ymm3
    addq    $64, %rsi
    addq    $-4, %rdi
    jne .LBB0_8

在这种情况下,循环得到了很好的优化。事实上,它几乎是大型阵列的最佳选择。对于小型阵列,就像在您的示例中一样,它在像我这样的某些处理器上并不是最佳的。事实上,AFAIK,展开的指令没有使用足够的寄存器来隐藏 FMA 单元的延迟(这是因为 LLVM 在内部生成次优代码)。可能需要较低级别的本机代码来修复此问题(至少,在 Numba 中没有简单的方法来修复此问题)。

© www.soinside.com 2019 - 2024. All rights reserved.