我对 numba 使用字典时的性能感兴趣。我做了以下实验:
from numpy.random import randint
import numba as nb
@nb.njit
def foo_numba(a, b, c):
N = 100**2
d = {}
for i in range(N):
d[(randint(N), randint(N), randint(N))] = (a, b, c)
return d
@nb.njit
def test_numba(numba_dict):
s = 0
for k in numba_dict:
s += numba_dict[k][2]
return s
def foo(a, b, c):
N = 100**2
d = {}
for i in range(N):
d[(randint(N), randint(N), randint(N))] = (a, b, c)
return d
def test(numba_dict):
s = 0
for k in numba_dict:
s += numba_dict[k][2]
return s
a = randint(10, size=10)
b = randint(10, size=10)
c = 1.3
t_numba = foo_numba(a, b, c)
dummy = test_numba(t_numba)
%timeit test_numba(t_numba)
t = foo(a, b, c)
%timeit test(t)
令我惊讶的是,我得到的输出是:
870 µs ± 6.36 µs per loop (mean ± std. dev. of 10 runs, 1,000 loops each)
654 µs ± 35.8 µs per loop (mean ± std. dev. of 10 runs, 1,000 loops each)
为什么 numba 代码比 cPython 代码慢这么多?我是否错误地使用了 numba?
如果您执行 a = tuple(a) 或 b = tuple(b) (似乎没有必要同时执行这两个操作!),那么速度减慢就会完全消失。
正如 Jerome 在评论中提到的,这里最大的问题是您将 JIT 编译的时间作为基准测试的一部分。
只需在基准测试开始之前添加
test_numba(foo_numba(a, b, c))
将确保在测试之前编译内容,并消除大约一半的差异。
除此之外,您的测试用例并不是特别有用。 Python 的
numba
s 已经进行了深度优化,C 层代码无法做任何有意义的事情来进一步优化访问,这与 dict
s/list
s 不同(其中 C 层代码可以直接获取原始存储)内存访问避免了在 Python 级别执行相同工作时涉及的多层内部函数调用)或 tuple
s/int
s(通过直接提取原始 C 值并执行 C 级别数学可以再次更有效地处理虽然现代 CPython 特殊情况下使用字节码解释器中的简单数学来绕过大部分开销,但开销较少)。理论上,您使用float
数组
可以从
numpy
中受益,但是您使用它们的方式意味着这些好处并不存在。具体来说,您的测试功能是:(不涉及实际的
numba
numpy
加载
tuple
的
int
(如果您只要求一个值,
dict
的
numpy
将返回 Python
randint
,而不是
int
标量);
numpy
迭代已经是重度优化了,
dict
没办法,单纯加载
numba
也不是可以进一步优化的事情。
(不涉及
tuple
numpy
中所说的
tuple
(再次,
dict
无能为力)
从生成的
numba
numpy
可以直接访问它,Python 需要构造一个包装对象,但这在 Python 中进行了大量优化,并且
numba
的直接访问收益非常大)当它只是一个值时很少;无法进行矢量化,第一次访问的成本差异不够大,并且您不会再次访问同一数组,因此不会产生大量工作收益)
将其添加到
numba
s
可以
提供帮助的一件事,如果它认识到
numba
是一个永远不会溢出 64 位值的 s
。但同样,这是最现代版本的 Python 已经优化了很多的东西(到 3.12,字节码会自我修改以处理
int
总是涉及
+=
的情况,并绕过旧 Python 的大部分开销)必须付费)。简而言之:
int
不会为
numba
s 做太多事情
当您进行大量不间断的简单数学运算时,dict
numba
数组进行批量数学运算。从
numpy
数组中的 dict
中提取单个值无法克服
numpy
在尝试
优化困难事物时所付出的成本,这些成本被浪费在简单的事情上。