为什么 numba 在使用字典时慢 25%

问题描述 投票:0回答:1

我对 numba 使用字典时的性能感兴趣。我做了以下实验:

from numpy.random import randint
import numba as nb

@nb.njit
def foo_numba(a, b, c):
    N = 100**2
    d = {}
    for i in range(N):
        d[(randint(N), randint(N), randint(N))] = (a, b, c)
    return d


@nb.njit
def test_numba(numba_dict):
    s = 0
    for k in numba_dict:
        s += numba_dict[k][2]
    return s



def foo(a, b, c):
    N = 100**2
    d = {}
    for i in range(N):
        d[(randint(N), randint(N), randint(N))] = (a, b, c)
    return d



def test(numba_dict):
    s = 0
    for k in numba_dict:
        s += numba_dict[k][2]
    return s

a = randint(10, size=10)
b = randint(10, size=10)
c = 1.3

t_numba = foo_numba(a, b, c)
dummy = test_numba(t_numba)
%timeit test_numba(t_numba)
t = foo(a, b, c)
%timeit test(t) 

令我惊讶的是,我得到的输出是:

870 µs ± 6.36 µs per loop (mean ± std. dev. of 10 runs, 1,000 loops each)
654 µs ± 35.8 µs per loop (mean ± std. dev. of 10 runs, 1,000 loops each)

为什么 numba 代码比 cPython 代码慢这么多?我是否错误地使用了 numba?

如果您执行 a = tuple(a) 或 b = tuple(b) (似乎没有必要同时执行这两个操作!),那么速度减慢就会完全消失。

python numba
1个回答
3
投票

正如 Jerome 在评论中提到的,这里最大的问题是您将 JIT 编译的时间作为基准测试的一部分。

只需在基准测试开始之前添加

test_numba(foo_numba(a, b, c))
将确保在测试之前编译内容,并消除大约一半的差异。

除此之外,您的测试用例并不是特别有用。 Python 的

numba
s 已经进行了深度优化,C 层代码无法做任何有意义的事情来进一步优化访问,这与
dict
s/
list
s 不同(其中 C 层代码可以直接获取原始存储)内存访问避免了在 Python 级别执行相同工作时涉及的多层内部函数调用)或
tuple
s/
int
s(通过直接提取原始 C 值并执行 C 级别数学可以再次更有效地处理虽然现代 CPython 特殊情况下使用字节码解释器中的简单数学来绕过大部分开销,但开销较少)。
理论上,您使用

float

数组

可以从
numpy中受益,但是您使用它们的方式
意味着这些好处并不存在。具体来说,您的测试功能是:
(不涉及实际的

numba
    类型)通过迭代
  1. numpy
     加载 
    tuple
    int
    (如果您只要求一个值,
    dict
    numpy
     将返回 Python 
    randint
    ,而不是
    int
     标量); 
    numpy
    迭代已经是重度优化了,
    dict
    没办法,单纯加载
    numba
    也不是可以进一步优化的事情。
    (不涉及
    tuple
  2. 类型)再次查找
  3. numpy
    中所说的
    tuple
    (再次,
    dict
    无能为力)
    从生成的 
    numba
  4. 数组中提取
  5. single 值(numpy
     可以直接访问它,Python 需要构造一个包装对象,但这在 Python 中进行了大量优化,并且 
    numba
     的直接访问收益非常大)当它只是一个值时很少;无法进行矢量化,第一次访问的成本差异不够大,并且您不会再次访问同一数组,因此不会产生大量工作收益)
    将其添加到
    numba
  6. 。这是
  7. s
     
    可以
     提供帮助的一件事,如果它认识到 
    numba 是一个永远不会溢出 64 位值的 s
    。但同样,这是最现代版本的 Python 已经优化了很多的东西(到 3.12,字节码会自我修改以处理 
    int
     总是涉及 
    +=
     的情况,并绕过旧 Python 的大部分开销)必须付费)。
    简而言之:

int
  1. 永远不会
     不会为 
    numbas 做太多事情 当您进行大量不间断的简单数学运算时,
    dict
  2. 的增益效果最佳,理想情况下是使用
  3. numba
     数组进行
    批量数学运算。从 numpy 数组中的 dict
     中提取单个值无法克服 
    numpy
    尝试
     优化困难事物时所付出的成本,这些成本被浪费在简单的事情上。
        
© www.soinside.com 2019 - 2024. All rights reserved.