我正在使用 scipy 的双重积分
dblquad
并且我正在尝试提高速度。我已经检查了网上提出的解决方案,但无法使它们发挥作用。
为了缓解这个问题,我准备了下面的比较。我做错了什么或者我可以做些什么来提高速度?
from scipy import integrate
import timeit
from numba import njit, jit
def bb_pure(q, z, x_loc, y_loc, B, L):
def f(y, x):
return (
3
* q
* z ** 3
/ (((x - B / 2 + x_loc) ** 2 + (y - L / 2 + y_loc) ** 2 + z * z) ** 2.5
)
)
return integrate.dblquad(f, 0, B, lambda x: 0, lambda x: L)[0]
def bb_numbanjit(q, z, x_loc, y_loc, B, L):
@njit
def f(y, x):
return (
3
* q
* z ** 3
/ (((x - B / 2 + x_loc) ** 2 + (y - L / 2 + y_loc) ** 2 + z * z) ** 2.5
)
)
return integrate.dblquad(f, 0, B, lambda x: 0, lambda x: L)[0]
def bb_numbajit(q, z, x_loc, y_loc, B, L):
@jit
def f(y, x):
return (
3
* q
* z ** 3
/ (((x - B / 2 + x_loc) ** 2 + (y - L / 2 + y_loc) ** 2 + z * z) ** 2.5
)
)
return integrate.dblquad(f, 0, B, lambda x: 0, lambda x: L)[0]
####
starttime = timeit.default_timer()
for i in range(100):
bb_pure(200, 5, 0, 0, i, i*2)
print("Pure Function:", round(timeit.default_timer() - starttime,2))
####
starttime = timeit.default_timer()
for i in range(100):
bb_numbanjit(200, 5, 0, 0, i, i*2)
print("Numba njit:", round(timeit.default_timer() - starttime,2))
####
starttime = timeit.default_timer()
for i in range(100):
bb_numbajit(200, 5, 0, 0, i, i*2)
print("Numba jit:", round(timeit.default_timer() - starttime,2))
结果
Pure Function: 3.22
Numba njit: 8.14
Numba jit: 8.15
主要问题是您正在计时 Numba 函数的编译时间。事实上,当调用
bb_numbanjit
时,@njit
装饰器告诉 Numba 声明一个 惰性编译函数,该函数在执行 第一次调用 时进行编译,因此在 integrate.dblquad
中。完全相同的行为也适用于 bb_numbajit
。 Numba 实现速度较慢,因为编译时间与执行时间相比相当长。问题是 Numba 函数是“闭包”,它读取需要新编译的本地参数。解决这个问题的典型方法是向 Numba 函数添加新参数并编译一次。由于这里需要一个闭包,因此可以使用代理闭包。这是一个例子:
@njit
def f_numba(y, x, q, z, x_loc, y_loc, B, L):
return (
3
* q
* z ** 3
/ (((x - B / 2 + x_loc) ** 2 + (y - L / 2 + y_loc) ** 2 + z * z) ** 2.5
)
)
def bb_numbanjit(q, z, x_loc, y_loc, B, L):
def f_proxy(y, x):
return f_numba(y, x, q, z, x_loc, y_loc, B, L)
return integrate.dblquad(f_proxy, 0, B, lambda x: 0, lambda x: L)[0]
这比
bb_pure
解决方案快两倍。
这个 Numba 解决方案速度并不快的一个原因是 Python 函数调用成本很高,尤其是当有很多参数时。另一个问题是,某些参数似乎是常量,而 Numba 不知道这一点,因为它们是作为运行时参数而不是编译时常量传递的。您可以将全局变量中的常量移动到,让 Numba 进一步优化代码(通过预先计算常量子表达式)。另请注意,Numba 函数已在内部由代理函数包装。对于这种基本的数值运算,代理函数有点昂贵(它们执行一些类型检查和纯 Python 对象到本机值的转换)。话虽这么说,由于关闭问题,这里没什么可做的。