Numba 对于简单的 for 循环性能不佳(Python 3.10)

问题描述 投票:0回答:1

我的代码:


from numba import njit
from functools import wraps
import time

def timeit(my_func):
    @wraps(my_func)
    def timed(*args, **kw):
    
        tstart = time.time()
        output = my_func(*args, **kw)
        tend = time.time()
        
        print('"{}" took {:.3f} ms to execute\n'.format(my_func.__name__, (tend - tstart) * 1000))
        return output
    return timed

@timeit
@njit
def calculate_smth(a,b):
    result = 0
    for i_a in range(a):
        for i_b in range(b):
            result = result + i_a + i_b
    return result

if __name__ == "__main__":
    value = calculate_smth(1000,1000)

没有 numba 装饰器,我的函数在 ~62ms 内完成,使用 njit 装饰器(预先编译后)需要 ~370ms。 有人可以解释我缺少什么吗?

python numba
1个回答
1
投票

JIT 代表即时 - 意味着代码在执行时编译 - 与 AOT - 提前编译相反。正如您可以在 Numba docs 中阅读的那样,默认情况下编译是惰性的,即。它发生在程序中的第一个函数执行时。 它还支持 AOT 编译,如

这里

所述 另一个选项是将

cache=True

参数传递给

numba.njit
装饰器。
作为一个具体示例,编辑代码以包含编译函数的虚拟调用,我们可以看到执行该函数实际上只需要很少的时间甚至根本不需要时间:

... if __name__ == "__main__": calculate_smth(1,1) value = calculate_smth(1000, 1000)

输出:

"calculate_smth" took 590.578 ms to execute "calculate_smth" took 0.000 ms to execute

© www.soinside.com 2019 - 2024. All rights reserved.