为什么添加break语句会显着减慢Numba函数的速度?

问题描述 投票:0回答:1

我有以下 Numba 功能:

@numba.njit
def count_in_range(arr, min_value, max_value):
    count = 0
    for a in arr:
        if min_value < a < max_value:
            count += 1
    return count

它计算数组中有多少个值在该范围内。

但是,我意识到我只需要确定它们是否存在即可。 所以我修改如下:

@numba.njit
def count_in_range2(arr, min_value, max_value):
    count = 0
    for a in arr:
        if min_value < a < max_value:
            count += 1
            break  # <---- break here
    return count

然后,该功能变得比更改之前。 在某些条件下,速度可能会慢 10 倍以上。

基准代码:

from timeit import timeit

rng = np.random.default_rng(0)
arr = rng.random(10 * 1000 * 1000)

# To compare on even conditions, choose the condition that does not terminate early.
min_value = 0.5
max_value = min_value - 1e-10
assert not np.any(np.logical_and(min_value <= arr, arr <= max_value))

n = 100
for f in (count_in_range, count_in_range2):
    f(arr, min_value, max_value)
    elapsed = timeit(lambda: f(arr, min_value, max_value), number=n) / n
    print(f"{f.__name__}: {elapsed * 1000:.3f} ms")

结果:

count_in_range: 3.351 ms
count_in_range2: 42.312 ms

进一步实验,我发现速度根据搜索范围(即

min_value
max_value
)变化很大。

在不同的搜索范围:

count_in_range2: 5.802 ms, range: (0.0, -1e-10)
count_in_range2: 15.408 ms, range: (0.1, 0.09999999990000001)
count_in_range2: 29.571 ms, range: (0.25, 0.2499999999)
count_in_range2: 42.514 ms, range: (0.5, 0.4999999999)
count_in_range2: 24.427 ms, range: (0.75, 0.7499999999)
count_in_range2: 12.547 ms, range: (0.9, 0.8999999999)
count_in_range2: 5.747 ms, range: (1.0, 0.9999999999)

有人可以向我解释一下发生了什么事吗?


我在 Python 3.10.11 下使用 Numba 0.58.1。 在 Windows 10 和 Ubuntu 22.04 上均得到确认。


编辑:

作为 Jérôme Richard 答案的附录:

正如他在评论中指出的那样,取决于搜索范围的性能差异可能是由于分支预测造成的。

例如,当

min_value
0.1
时,
min_value < a
有 90% 的机会为真,
a < max_value
有 90% 的机会为假。所以从数学上来说它可以被正确预测,准确率达到 81%。我不知道CPU是如何做到这一点的,但我想出了一种方法来检查这个逻辑是否正确。

首先,用高于和低于阈值的值对数组进行分区,其次,以一定的错误概率将其混合。当数组被分区时,分支预测未命中的数量应该不受阈值的影响。当我们在其中包含错误时,未命中的数量应该根据错误而增加。

这是更新后的基准代码:

from timeit import timeit
import numba
import numpy as np


@numba.njit
def count_in_range(arr, min_value, max_value):
    count = 0
    for a in arr:
        if min_value < a < max_value:
            count += 1
    return count


@numba.njit
def count_in_range2(arr, min_value, max_value):
    count = 0
    for a in arr:
        if min_value < a < max_value:
            count += 1
            break  # <---- break here
    return count


def partition(arr, threshold):
    """Place the elements smaller than the threshold in the front and the elements larger than the threshold in the back."""
    less = arr[arr < threshold]
    more = arr[~(arr < threshold)]
    return np.concatenate((less, more))


def partition_with_error(arr, threshold, error_rate):
    """Same as partition, but includes errors with a certain probability."""
    less = arr[arr < threshold]
    more = arr[~(arr < threshold)]
    less_error, less_correct = np.split(less, [int(len(less) * error_rate)])
    more_error, more_correct = np.split(more, [int(len(more) * error_rate)])
    mostly_less = np.concatenate((less_correct, more_error))
    mostly_more = np.concatenate((more_correct, less_error))
    rng = np.random.default_rng(0)
    rng.shuffle(mostly_less)
    rng.shuffle(mostly_more)
    out = np.concatenate((mostly_less, mostly_more))
    assert np.array_equal(np.sort(out), np.sort(arr))
    return out


def bench(f, arr, min_value, max_value, n=10, info=""):
    f(arr, min_value, max_value)
    elapsed = timeit(lambda: f(arr, min_value, max_value), number=n) / n
    print(f"{f.__name__}: {elapsed * 1000:.3f} ms, min_value: {min_value:.1f}, {info}")


def main():
    rng = np.random.default_rng(0)
    arr = rng.random(10 * 1000 * 1000)
    thresholds = np.linspace(0, 1, 11)

    print("#", "-" * 10, "As for comparison", "-" * 10)
    bench(
        count_in_range,
        arr,
        min_value=0.5,
        max_value=0.5 - 1e-10,
    )

    print("\n#", "-" * 10, "Random Data", "-" * 10)
    for min_value in thresholds:
        bench(
            count_in_range2,
            arr,
            min_value=min_value,
            max_value=min_value - 1e-10,
        )

    print("\n#", "-" * 10, "Partitioned (Yet Still Random) Data", "-" * 10)
    for min_value in thresholds:
        bench(
            count_in_range2,
            partition(arr, threshold=min_value),
            min_value=min_value,
            max_value=min_value - 1e-10,
        )

    print("\n#", "-" * 10, "Partitioned Data with Probabilistic Errors", "-" * 10)
    for ratio in thresholds:
        bench(
            count_in_range2,
            partition_with_error(arr, threshold=0.5, error_rate=ratio),
            min_value=0.5,
            max_value=0.5 - 1e-10,
            info=f"error: {ratio:.0%}",
        )


if __name__ == "__main__":
    main()

结果:

# ---------- As for comparison ----------
count_in_range: 3.518 ms, min_value: 0.5, 

# ---------- Random Data ----------
count_in_range2: 5.958 ms, min_value: 0.0, 
count_in_range2: 15.390 ms, min_value: 0.1, 
count_in_range2: 24.715 ms, min_value: 0.2, 
count_in_range2: 33.749 ms, min_value: 0.3, 
count_in_range2: 40.007 ms, min_value: 0.4, 
count_in_range2: 42.168 ms, min_value: 0.5, 
count_in_range2: 37.427 ms, min_value: 0.6, 
count_in_range2: 28.763 ms, min_value: 0.7, 
count_in_range2: 20.089 ms, min_value: 0.8, 
count_in_range2: 12.638 ms, min_value: 0.9, 
count_in_range2: 5.876 ms, min_value: 1.0, 

# ---------- Partitioned (Yet Still Random) Data ----------
count_in_range2: 6.006 ms, min_value: 0.0, 
count_in_range2: 5.999 ms, min_value: 0.1, 
count_in_range2: 5.953 ms, min_value: 0.2, 
count_in_range2: 5.952 ms, min_value: 0.3, 
count_in_range2: 5.940 ms, min_value: 0.4, 
count_in_range2: 6.870 ms, min_value: 0.5, 
count_in_range2: 5.939 ms, min_value: 0.6, 
count_in_range2: 5.896 ms, min_value: 0.7, 
count_in_range2: 5.899 ms, min_value: 0.8, 
count_in_range2: 5.880 ms, min_value: 0.9, 
count_in_range2: 5.884 ms, min_value: 1.0, 

# ---------- Partitioned Data with Probabilistic Errors ----------
count_in_range2: 5.939 ms, min_value: 0.5, error: 0%
count_in_range2: 14.015 ms, min_value: 0.5, error: 10%
count_in_range2: 22.599 ms, min_value: 0.5, error: 20%
count_in_range2: 31.763 ms, min_value: 0.5, error: 30%
count_in_range2: 39.391 ms, min_value: 0.5, error: 40%
count_in_range2: 42.227 ms, min_value: 0.5, error: 50%
count_in_range2: 38.748 ms, min_value: 0.5, error: 60%
count_in_range2: 31.758 ms, min_value: 0.5, error: 70%
count_in_range2: 22.600 ms, min_value: 0.5, error: 80%
count_in_range2: 14.090 ms, min_value: 0.5, error: 90%
count_in_range2: 6.027 ms, min_value: 0.5, error: 100%

我对这个结果很满意。

python performance numba
1个回答
6
投票

TL;DR:Numba 使用 LLVM,当存在

break
时,它无法自动矢量化代码。解决此问题的一种方法是逐块计算操作。


Numba 基于 LLVM 编译器工具链,将 Python 代码编译为本机代码。 Numlba 从 Python 代码生成 LLVM 中间表示 (IR),然后将其提供给 LLVM,以便它可以生成快速的本机代码。所有低级优化都是由 LLVM 完成的,实际上并不是 Numba 本身。在这种情况下,当存在 break

 时,LLVM 无法自动矢量化代码。 Numba 在这里不进行任何模式识别,也不在 GPU 上运行任何代码(基本 
numba.njit
代码始终在 CPU 上运行)。

请注意,本文中的“矢量化”意味着从标量 IR 代码生成 SIMD 指令。这个词在 Numpy Python 代码的上下文中具有不同的含义(这意味着调用本机函数以减少开销,但本机函数不一定使用 SIMD 指令)。


在引擎盖下

我用 Clang 重现了这个问题,Clang 是一个 C++ 编译器,也使用 LLVM 工具链。这是等效的 C++ 代码:

#include <cstdint>
#include <cstdlib>
#include <vector>

int64_t count_in_range(const std::vector<double>& arr, double min_value, double max_value)
{
    int64_t count = 0;

    for(int64_t i=0 ; i<arr.size() ; ++i)
    {
        double a = arr[i];

        if (min_value < a && a < max_value)
        {
            count += 1;
        }
    }

    return count;
}

此代码结果在以下汇编主循环中:

.LBB0_6:                                # =>This Inner Loop Header: Depth=1
        vmovupd ymm8, ymmword ptr [rcx + 8*rax]
        vmovupd ymm9, ymmword ptr [rcx + 8*rax + 32]
        vmovupd ymm10, ymmword ptr [rcx + 8*rax + 64]
        vmovupd ymm11, ymmword ptr [rcx + 8*rax + 96]
        vcmpltpd        ymm12, ymm2, ymm8
        vcmpltpd        ymm13, ymm2, ymm9
        vcmpltpd        ymm14, ymm2, ymm10
        vcmpltpd        ymm15, ymm2, ymm11
        vcmpltpd        ymm8, ymm8, ymm4
        vandpd  ymm8, ymm12, ymm8
        vpsubq  ymm3, ymm3, ymm8
        vcmpltpd        ymm8, ymm9, ymm4
        vandpd  ymm8, ymm13, ymm8
        vpsubq  ymm5, ymm5, ymm8
        vcmpltpd        ymm8, ymm10, ymm4
        vandpd  ymm8, ymm14, ymm8
        vpsubq  ymm6, ymm6, ymm8
        vcmpltpd        ymm8, ymm11, ymm4
        vandpd  ymm8, ymm15, ymm8
        vpsubq  ymm7, ymm7, ymm8
        add     rax, 16
        cmp     rsi, rax
        jne     .LBB0_6

指令

vmovupd
vcmpltpd
vandpd
等说明汇编代码完全使用了SIMD指令

如果我们添加一个

break
,那么情况就不再是这样了:

.LBB0_4:                                # =>This Inner Loop Header: Depth=1
        vmovsd  xmm2, qword ptr [rcx + 8*rsi]   # xmm2 = mem[0],zero
        vcmpltpd        xmm3, xmm2, xmm1
        vcmpltpd        xmm2, xmm0, xmm2
        vandpd  xmm2, xmm2, xmm3
        vmovq   rdi, xmm2
        sub     rax, rdi
        test    dil, 1
        jne     .LBB0_2
        lea     rdi, [rsi + 1]
        cmp     rdx, rsi
        mov     rsi, rdi
        jne     .LBB0_4

此处

vmovsd
在循环中移动标量值(并且
rsi
每次循环迭代都会增加 1)。后面的代码效率明显较低。事实上,每次迭代一次仅对一项进行操作,而不是之前的代码对 16 项进行操作。

我们可以使用编译标志

-Rpass-missed=loop-vectorize
来检查循环是否确实没有矢量化。 Clang 明确报告:

备注:循环未矢量化 [-Rpass-missed=loop-vectorize

要知道原因,我们可以使用标志

-Rpass-analysis=loop-vectorize
:

循环未矢量化:无法确定循环迭代次数[-Rpass-analysis =循环向量化]

因此,我们可以得出结论,LLVM 优化器不支持这种代码模式。


解决方案

避免此问题的一种方法是对块进行操作。每个块的计算可以通过 Clang 完全矢量化,并且您可以在第一个块处尽早打破条件。

这是未经测试的代码:

@numba.njit
def count_in_range_faster(arr, min_value, max_value):
    count = 0
    for i in range(0, arr.size, 16):
        if arr.size - i >= 16:
            # Optimized SIMD-friendly computation of 1 chunk of size 16
            tmp_view = arr[i:i+16]
            for j in range(0, 16):
                if min_value < tmp_view[j] < max_value:
                    count += 1
            if count > 0:
                return 1
        else:
            # Fallback implementation (variable-sized chunk)
            for j in range(i, arr.size):
                if min_value < arr[j] < max_value:
                    count += 1
            if count > 0:
                return 1
    return 0

C++ 等效代码已正确矢量化。需要使用

count_in_range_faster.inspect_llvm()
检查 Numba 代码是否也是如此,但以下计时表明上述实现比其他两个实现更快。


性能结果

以下是使用 Numba 0.56.0 的 Xeon W-2255 CPU 机器上的结果:

count_in_range:          7.112 ms
count_in_range2:        35.317 ms
count_in_range_faster:   5.827 ms     <----------
© www.soinside.com 2019 - 2024. All rights reserved.