我有以下 Numba 功能:
@numba.njit
def count_in_range(arr, min_value, max_value):
count = 0
for a in arr:
if min_value < a < max_value:
count += 1
return count
它计算数组中有多少个值在该范围内。
但是,我意识到我只需要确定它们是否存在即可。 所以我修改如下:
@numba.njit
def count_in_range2(arr, min_value, max_value):
count = 0
for a in arr:
if min_value < a < max_value:
count += 1
break # <---- break here
return count
然后,该功能变得比更改之前慢。 在某些条件下,速度可能会慢 10 倍以上。
基准代码:
from timeit import timeit
rng = np.random.default_rng(0)
arr = rng.random(10 * 1000 * 1000)
# To compare on even conditions, choose the condition that does not terminate early.
min_value = 0.5
max_value = min_value - 1e-10
assert not np.any(np.logical_and(min_value <= arr, arr <= max_value))
n = 100
for f in (count_in_range, count_in_range2):
f(arr, min_value, max_value)
elapsed = timeit(lambda: f(arr, min_value, max_value), number=n) / n
print(f"{f.__name__}: {elapsed * 1000:.3f} ms")
结果:
count_in_range: 3.351 ms
count_in_range2: 42.312 ms
进一步实验,我发现速度根据搜索范围(即
min_value
和max_value
)变化很大。
在不同的搜索范围:
count_in_range2: 5.802 ms, range: (0.0, -1e-10)
count_in_range2: 15.408 ms, range: (0.1, 0.09999999990000001)
count_in_range2: 29.571 ms, range: (0.25, 0.2499999999)
count_in_range2: 42.514 ms, range: (0.5, 0.4999999999)
count_in_range2: 24.427 ms, range: (0.75, 0.7499999999)
count_in_range2: 12.547 ms, range: (0.9, 0.8999999999)
count_in_range2: 5.747 ms, range: (1.0, 0.9999999999)
有人可以向我解释一下发生了什么事吗?
我在 Python 3.10.11 下使用 Numba 0.58.1。 在 Windows 10 和 Ubuntu 22.04 上均得到确认。
作为 Jérôme Richard 答案的附录:
正如他在评论中指出的那样,取决于搜索范围的性能差异可能是由于分支预测造成的。
例如,当
min_value
为 0.1
时,min_value < a
有 90% 的机会为真,a < max_value
有 90% 的机会为假。所以从数学上来说它可以被正确预测,准确率达到 81%。我不知道CPU是如何做到这一点的,但我想出了一种方法来检查这个逻辑是否正确。
首先,用高于和低于阈值的值对数组进行分区,其次,以一定的错误概率将其混合。当数组被分区时,分支预测未命中的数量应该不受阈值的影响。当我们在其中包含错误时,未命中的数量应该根据错误而增加。
这是更新后的基准代码:
from timeit import timeit
import numba
import numpy as np
@numba.njit
def count_in_range(arr, min_value, max_value):
count = 0
for a in arr:
if min_value < a < max_value:
count += 1
return count
@numba.njit
def count_in_range2(arr, min_value, max_value):
count = 0
for a in arr:
if min_value < a < max_value:
count += 1
break # <---- break here
return count
def partition(arr, threshold):
"""Place the elements smaller than the threshold in the front and the elements larger than the threshold in the back."""
less = arr[arr < threshold]
more = arr[~(arr < threshold)]
return np.concatenate((less, more))
def partition_with_error(arr, threshold, error_rate):
"""Same as partition, but includes errors with a certain probability."""
less = arr[arr < threshold]
more = arr[~(arr < threshold)]
less_error, less_correct = np.split(less, [int(len(less) * error_rate)])
more_error, more_correct = np.split(more, [int(len(more) * error_rate)])
mostly_less = np.concatenate((less_correct, more_error))
mostly_more = np.concatenate((more_correct, less_error))
rng = np.random.default_rng(0)
rng.shuffle(mostly_less)
rng.shuffle(mostly_more)
out = np.concatenate((mostly_less, mostly_more))
assert np.array_equal(np.sort(out), np.sort(arr))
return out
def bench(f, arr, min_value, max_value, n=10, info=""):
f(arr, min_value, max_value)
elapsed = timeit(lambda: f(arr, min_value, max_value), number=n) / n
print(f"{f.__name__}: {elapsed * 1000:.3f} ms, min_value: {min_value:.1f}, {info}")
def main():
rng = np.random.default_rng(0)
arr = rng.random(10 * 1000 * 1000)
thresholds = np.linspace(0, 1, 11)
print("#", "-" * 10, "As for comparison", "-" * 10)
bench(
count_in_range,
arr,
min_value=0.5,
max_value=0.5 - 1e-10,
)
print("\n#", "-" * 10, "Random Data", "-" * 10)
for min_value in thresholds:
bench(
count_in_range2,
arr,
min_value=min_value,
max_value=min_value - 1e-10,
)
print("\n#", "-" * 10, "Partitioned (Yet Still Random) Data", "-" * 10)
for min_value in thresholds:
bench(
count_in_range2,
partition(arr, threshold=min_value),
min_value=min_value,
max_value=min_value - 1e-10,
)
print("\n#", "-" * 10, "Partitioned Data with Probabilistic Errors", "-" * 10)
for ratio in thresholds:
bench(
count_in_range2,
partition_with_error(arr, threshold=0.5, error_rate=ratio),
min_value=0.5,
max_value=0.5 - 1e-10,
info=f"error: {ratio:.0%}",
)
if __name__ == "__main__":
main()
结果:
# ---------- As for comparison ----------
count_in_range: 3.518 ms, min_value: 0.5,
# ---------- Random Data ----------
count_in_range2: 5.958 ms, min_value: 0.0,
count_in_range2: 15.390 ms, min_value: 0.1,
count_in_range2: 24.715 ms, min_value: 0.2,
count_in_range2: 33.749 ms, min_value: 0.3,
count_in_range2: 40.007 ms, min_value: 0.4,
count_in_range2: 42.168 ms, min_value: 0.5,
count_in_range2: 37.427 ms, min_value: 0.6,
count_in_range2: 28.763 ms, min_value: 0.7,
count_in_range2: 20.089 ms, min_value: 0.8,
count_in_range2: 12.638 ms, min_value: 0.9,
count_in_range2: 5.876 ms, min_value: 1.0,
# ---------- Partitioned (Yet Still Random) Data ----------
count_in_range2: 6.006 ms, min_value: 0.0,
count_in_range2: 5.999 ms, min_value: 0.1,
count_in_range2: 5.953 ms, min_value: 0.2,
count_in_range2: 5.952 ms, min_value: 0.3,
count_in_range2: 5.940 ms, min_value: 0.4,
count_in_range2: 6.870 ms, min_value: 0.5,
count_in_range2: 5.939 ms, min_value: 0.6,
count_in_range2: 5.896 ms, min_value: 0.7,
count_in_range2: 5.899 ms, min_value: 0.8,
count_in_range2: 5.880 ms, min_value: 0.9,
count_in_range2: 5.884 ms, min_value: 1.0,
# ---------- Partitioned Data with Probabilistic Errors ----------
count_in_range2: 5.939 ms, min_value: 0.5, error: 0%
count_in_range2: 14.015 ms, min_value: 0.5, error: 10%
count_in_range2: 22.599 ms, min_value: 0.5, error: 20%
count_in_range2: 31.763 ms, min_value: 0.5, error: 30%
count_in_range2: 39.391 ms, min_value: 0.5, error: 40%
count_in_range2: 42.227 ms, min_value: 0.5, error: 50%
count_in_range2: 38.748 ms, min_value: 0.5, error: 60%
count_in_range2: 31.758 ms, min_value: 0.5, error: 70%
count_in_range2: 22.600 ms, min_value: 0.5, error: 80%
count_in_range2: 14.090 ms, min_value: 0.5, error: 90%
count_in_range2: 6.027 ms, min_value: 0.5, error: 100%
我对这个结果很满意。
TL;DR:Numba 使用 LLVM,当存在
break
时,它无法自动矢量化代码。解决此问题的一种方法是逐块计算操作。
Numba 基于 LLVM 编译器工具链,将 Python 代码编译为本机代码。 Numlba 从 Python 代码生成 LLVM 中间表示 (IR),然后将其提供给 LLVM,以便它可以生成快速的本机代码。所有低级优化都是由 LLVM 完成的,实际上并不是 Numba 本身。在这种情况下,当存在 break
numba.njit
代码始终在 CPU 上运行)。
请注意,本文中的“矢量化”意味着从标量 IR 代码生成 SIMD 指令。这个词在 Numpy Python 代码的上下文中具有不同的含义(这意味着调用本机函数以减少开销,但本机函数不一定使用 SIMD 指令)。
我用 Clang 重现了这个问题,Clang 是一个 C++ 编译器,也使用 LLVM 工具链。这是等效的 C++ 代码:
#include <cstdint>
#include <cstdlib>
#include <vector>
int64_t count_in_range(const std::vector<double>& arr, double min_value, double max_value)
{
int64_t count = 0;
for(int64_t i=0 ; i<arr.size() ; ++i)
{
double a = arr[i];
if (min_value < a && a < max_value)
{
count += 1;
}
}
return count;
}
此代码结果在以下汇编主循环中:
.LBB0_6: # =>This Inner Loop Header: Depth=1
vmovupd ymm8, ymmword ptr [rcx + 8*rax]
vmovupd ymm9, ymmword ptr [rcx + 8*rax + 32]
vmovupd ymm10, ymmword ptr [rcx + 8*rax + 64]
vmovupd ymm11, ymmword ptr [rcx + 8*rax + 96]
vcmpltpd ymm12, ymm2, ymm8
vcmpltpd ymm13, ymm2, ymm9
vcmpltpd ymm14, ymm2, ymm10
vcmpltpd ymm15, ymm2, ymm11
vcmpltpd ymm8, ymm8, ymm4
vandpd ymm8, ymm12, ymm8
vpsubq ymm3, ymm3, ymm8
vcmpltpd ymm8, ymm9, ymm4
vandpd ymm8, ymm13, ymm8
vpsubq ymm5, ymm5, ymm8
vcmpltpd ymm8, ymm10, ymm4
vandpd ymm8, ymm14, ymm8
vpsubq ymm6, ymm6, ymm8
vcmpltpd ymm8, ymm11, ymm4
vandpd ymm8, ymm15, ymm8
vpsubq ymm7, ymm7, ymm8
add rax, 16
cmp rsi, rax
jne .LBB0_6
指令
vmovupd
、vcmpltpd
和vandpd
等说明汇编代码完全使用了SIMD指令。
如果我们添加一个
break
,那么情况就不再是这样了:
.LBB0_4: # =>This Inner Loop Header: Depth=1
vmovsd xmm2, qword ptr [rcx + 8*rsi] # xmm2 = mem[0],zero
vcmpltpd xmm3, xmm2, xmm1
vcmpltpd xmm2, xmm0, xmm2
vandpd xmm2, xmm2, xmm3
vmovq rdi, xmm2
sub rax, rdi
test dil, 1
jne .LBB0_2
lea rdi, [rsi + 1]
cmp rdx, rsi
mov rsi, rdi
jne .LBB0_4
此处
vmovsd
在循环中移动标量值(并且 rsi
每次循环迭代都会增加 1)。后面的代码效率明显较低。事实上,每次迭代一次仅对一项进行操作,而不是之前的代码对 16 项进行操作。
我们可以使用编译标志
-Rpass-missed=loop-vectorize
来检查循环是否确实没有矢量化。 Clang 明确报告:
备注:循环未矢量化 [-Rpass-missed=loop-vectorize
要知道原因,我们可以使用标志
-Rpass-analysis=loop-vectorize
:
循环未矢量化:无法确定循环迭代次数[-Rpass-analysis =循环向量化]
因此,我们可以得出结论,LLVM 优化器不支持这种代码模式。
避免此问题的一种方法是对块进行操作。每个块的计算可以通过 Clang 完全矢量化,并且您可以在第一个块处尽早打破条件。
这是未经测试的代码:
@numba.njit
def count_in_range_faster(arr, min_value, max_value):
count = 0
for i in range(0, arr.size, 16):
if arr.size - i >= 16:
# Optimized SIMD-friendly computation of 1 chunk of size 16
tmp_view = arr[i:i+16]
for j in range(0, 16):
if min_value < tmp_view[j] < max_value:
count += 1
if count > 0:
return 1
else:
# Fallback implementation (variable-sized chunk)
for j in range(i, arr.size):
if min_value < arr[j] < max_value:
count += 1
if count > 0:
return 1
return 0
C++ 等效代码已正确矢量化。需要使用
count_in_range_faster.inspect_llvm()
检查 Numba 代码是否也是如此,但以下计时表明上述实现比其他两个实现更快。
以下是使用 Numba 0.56.0 的 Xeon W-2255 CPU 机器上的结果:
count_in_range: 7.112 ms
count_in_range2: 35.317 ms
count_in_range_faster: 5.827 ms <----------