从 numba 修饰函数调用 root_scalar

问题描述 投票:0回答:1

find_root
nb.jit
修饰时,以下代码将失败。这是一个玩具示例,但其想法是能够找到值数组的标量函数(或可能使用
root
的多元函数)的根并将它们存储在 numpy 数组中。

错误信息:

TypingError: cannot determine Numba type of <class 'function'>

import numba as nb
import numpy as np
from scipy.optimize import root_scalar

a = 3.0
b = 1.0
c = -10.5

@nb.jit(nopython=True)
def f(x):
    return a*x**2 + b*x + c

@nb.jit(nopython=True)
def fprime(x):
    return 2*a*x + b

@nb.jit(nopython=True)
def fprime2(x):
    return 2*a

@nb.jit(nopython=True) # <-- Commenting this line makes the code work but it is slow
def findroot(arr): 
    for i in range(len(arr)):
        arr[i] = root_scalar(f, fprime=fprime, fprime2=fprime2, x0=0).root
        
if __name__ == '__main__':
    arr = np.zeros(20, np.float)
    
    import timeit  
    start = timeit.time.process_time()
    findroot(arr)
    end = timeit.time.process_time()
    print(end - start)
numpy scipy numeric numba
1个回答
0
投票

从 numba 修饰函数调用 root_scalar

我认为您需要首先成功进行 JIT 编译

root_scalar
,这可能需要一些努力。由于您提供的是
fprime
而不是括号,因此它将在幕后使用
newton
,因此您确实需要对其进行 JIT 编译并直接调用它。但是:

...这个想法是能够找到标量函数的根...对于值数组并将它们存储在 numpy 数组中。

你已经可以用

newton
做到了。

import numba as nb
import numpy as np
from scipy.optimize import newton

a = 3.0
b = 1.0
c = -10.5

def f(x):
    return a*x**2 + b*x + c

def fprime(x):
    return 2*a*x + b

def fprime2(x):
    return np.full_like(x, 2*a)

x0 = 0
%timeit newton(f, x0=x0, fprime=fprime, fprime2=fprime2)
# 289 µs ± 8.27 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

x0 = np.zeros(1000)
%timeit newton(f, x0=x0, fprime=fprime, fprime2=fprime2)
# 416 µs ± 13.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

举一个更有趣的例子:

import matplotlib.pyplot as plt
c = np.linspace(-10.5, -1000, 1000)
root = newton(f, x0=x0, fprime=fprime, fprime2=fprime2)
plt.plot(c, root)

© www.soinside.com 2019 - 2024. All rights reserved.