当
find_root
用 nb.jit
修饰时,以下代码将失败。这是一个玩具示例,但其想法是能够找到值数组的标量函数(或可能使用 root
的多元函数)的根并将它们存储在 numpy 数组中。
错误信息:
TypingError: cannot determine Numba type of <class 'function'>
import numba as nb
import numpy as np
from scipy.optimize import root_scalar
a = 3.0
b = 1.0
c = -10.5
@nb.jit(nopython=True)
def f(x):
return a*x**2 + b*x + c
@nb.jit(nopython=True)
def fprime(x):
return 2*a*x + b
@nb.jit(nopython=True)
def fprime2(x):
return 2*a
@nb.jit(nopython=True) # <-- Commenting this line makes the code work but it is slow
def findroot(arr):
for i in range(len(arr)):
arr[i] = root_scalar(f, fprime=fprime, fprime2=fprime2, x0=0).root
if __name__ == '__main__':
arr = np.zeros(20, np.float)
import timeit
start = timeit.time.process_time()
findroot(arr)
end = timeit.time.process_time()
print(end - start)
从 numba 修饰函数调用 root_scalar
我认为您需要首先成功进行 JIT 编译
root_scalar
,这可能需要一些努力。由于您提供的是 fprime
而不是括号,因此它将在幕后使用 newton
,因此您确实需要对其进行 JIT 编译并直接调用它。但是:
...这个想法是能够找到标量函数的根...对于值数组并将它们存储在 numpy 数组中。
那你已经可以用
newton
做到了。
import numba as nb
import numpy as np
from scipy.optimize import newton
a = 3.0
b = 1.0
c = -10.5
def f(x):
return a*x**2 + b*x + c
def fprime(x):
return 2*a*x + b
def fprime2(x):
return np.full_like(x, 2*a)
x0 = 0
%timeit newton(f, x0=x0, fprime=fprime, fprime2=fprime2)
# 289 µs ± 8.27 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
x0 = np.zeros(1000)
%timeit newton(f, x0=x0, fprime=fprime, fprime2=fprime2)
# 416 µs ± 13.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
举一个更有趣的例子:
import matplotlib.pyplot as plt
c = np.linspace(-10.5, -1000, 1000)
root = newton(f, x0=x0, fprime=fprime, fprime2=fprime2)
plt.plot(c, root)