最快的双积分方法

问题描述 投票:0回答:1

我正在使用 scipy 的双重积分

dblquad
并且我正在尝试提高速度。我已经检查了网上提出的解决方案,但无法使它们发挥作用。 为了缓解这个问题,我准备了下面的比较。我做错了什么或者我可以做些什么来提高速度?

from scipy import integrate
import timeit
from numba import njit, jit

def bb_pure(q, z, x_loc, y_loc, B, L):

    def f(y, x):
        return (
            3
            * q
            * z ** 3
            / (((x - B / 2 + x_loc) ** 2 + (y - L / 2 + y_loc) ** 2 + z * z) ** 2.5
            )
        )

    return integrate.dblquad(f, 0, B, lambda x: 0, lambda x: L)[0]

def bb_numbanjit(q, z, x_loc, y_loc, B, L):
    
    @njit
    def f(y, x):
        return (
            3
            * q
            * z ** 3
            / (((x - B / 2 + x_loc) ** 2 + (y - L / 2 + y_loc) ** 2 + z * z) ** 2.5
            )
        )

    return integrate.dblquad(f, 0, B, lambda x: 0, lambda x: L)[0]

def bb_numbajit(q, z, x_loc, y_loc, B, L):
    
    @jit
    def f(y, x):
        return (
            3
            * q
            * z ** 3
            / (((x - B / 2 + x_loc) ** 2 + (y - L / 2 + y_loc) ** 2 + z * z) ** 2.5
            )
        )

    return integrate.dblquad(f, 0, B, lambda x: 0, lambda x: L)[0]

####

starttime = timeit.default_timer()
for i in range(100):
    bb_pure(200, 5, 0, 0, i, i*2)

print("Pure Function:", round(timeit.default_timer() - starttime,2))

####

starttime = timeit.default_timer()
for i in range(100):
    bb_numbanjit(200, 5, 0, 0, i, i*2)

print("Numba njit:", round(timeit.default_timer() - starttime,2))

####

starttime = timeit.default_timer()
for i in range(100):
    bb_numbajit(200, 5, 0, 0, i, i*2)

print("Numba jit:", round(timeit.default_timer() - starttime,2))

结果

Pure Function: 3.22
Numba njit: 8.14
Numba jit: 8.15
python scipy numba numerical-integration
1个回答
1
投票

主要问题是您正在计时 Numba 函数的编译时间。事实上,当调用

bb_numbanjit
时,
@njit
装饰器告诉 Numba 声明一个 惰性编译函数,该函数在执行 第一次调用 时进行编译,因此在
integrate.dblquad
中。完全相同的行为也适用于
bb_numbajit
。 Numba 实现速度较慢,因为编译时间与执行时间相比相当长。问题是 Numba 函数是“闭包”,它读取需要新编译的本地参数。解决这个问题的典型方法是向 Numba 函数添加新参数并编译一次。由于这里需要一个闭包,因此可以使用代理闭包。这是一个例子: @njit def f_numba(y, x, q, z, x_loc, y_loc, B, L): return ( 3 * q * z ** 3 / (((x - B / 2 + x_loc) ** 2 + (y - L / 2 + y_loc) ** 2 + z * z) ** 2.5 ) ) def bb_numbanjit(q, z, x_loc, y_loc, B, L): def f_proxy(y, x): return f_numba(y, x, q, z, x_loc, y_loc, B, L) return integrate.dblquad(f_proxy, 0, B, lambda x: 0, lambda x: L)[0]

这比
bb_pure

解决方案快两倍

这个 Numba 解决方案速度并不快的一个原因是 Python 函数调用成本很高,尤其是当有很多参数时。另一个问题是,某些参数似乎是常量,而 Numba 不知道这一点,因为它们是作为运行时参数而不是编译时常量传递的。您可以将全局变量中的常量移动到,让 Numba 进一步优化代码(通过预先计算常量子表达式)。

另请注意,Numba 函数已在内部由代理函数包装。对于这种基本的数值运算,代理函数有点昂贵(它们执行一些类型检查和纯 Python 对象到本机值的转换)。话虽这么说,由于关闭问题,这里没什么可做的。

© www.soinside.com 2019 - 2024. All rights reserved.