使用numba优化容错率的同时循环。

问题描述 投票:0回答:1

我在使用numba进行优化时有一个疑问。我正在编写一个定点迭代函数,计算一个名为gamma的数组的值,该数组满足公式f(gamma)=gamma。我试图用python包Numba来优化这个函数。似乎是这样的。

@jit
def fixed_point(gamma_guess):
    for i in range(17):
        gamma_guess=f(gamma_guess)
    return gamma_guess

Numba能够很好地优化这个函数,因为它知道要执行多少次运算,17次,而且工作速度很快。但是我需要控制我所需要的伽马值的误差公差,我的意思是,一个伽马值和下一个定点迭代得到的伽马值的差值应该小于某个数字epsilon=0.01,然后我试了下

@jit
def fixed_point(gamma_guess):
    err=1000
    gamma_old=gamma_guess.copy()
    while(error>0.01):
        gamma_guess=f(gamma_guess)
        err=np.max(abs(gamma_guess-gamma_old))
        gamma_old=gamma_guess.copy()
    return gamma_guess

它也能计算出想要的结果,但没有上次实现的快,慢了很多。我想这是因为Numba不能很好地优化while循环,因为我们不知道它什么时候会停止。有什么办法可以让我优化一下,运行速度和上次实现的一样快?

编辑。

这是我使用的f

from scipy import fftpack as sp
S=0.01
Amu=0.7
@jit 
def f(gammaa,z,zal,kappa):
    ka=sp.diff(kappa)
    gamma0=gammaa
    for i in range(N):
        suma=0
        for j in range(N):
            if (abs(j-i))%2 ==1:
                if((z[i]-z[j])==0):
                    suma+=(gamma0[j]/(z[i]-z[j]))   
        gamma0[i]=2.0*Amu*np.real(-(zal[i]/z[i])+zal[i]*(1.0/(2*np.pi*1j))*suma*2*h)+S*ka[i]
    return  gamma0

我总是用 np.ones(2048)*0.5 作为初始猜测,我传递给函数的其他参数是 z=np.cos(alphas)+1j*(np.sin(alphas)+0.1) , zal=-np.sin(alphas)+1j*np.cos(alphas) , kappa=np.ones(2048)alphas=np.arange(0,2*np.pi,2*np.pi/2048)

python python-3.x numpy numba
1个回答
0
投票

我做了一个小的测试脚本,看看是否能重现你的错误。

import numba as nb

from IPython import get_ipython
ipython = get_ipython()

@nb.jit(nopython=True)
def f(x):
    return (x+1)/x


def fixed_point_for(x):
    for _ in range(17):
        x = f(x)
    return x

@nb.jit(nopython=True)
def fixed_point_for_nb(x):
    for _ in range(17):
        x = f(x)
    return x

def fixed_point_while(x):
    error=1
    x_old = x
    while error>0.01:
        x = f(x)
        error = abs(x_old-x)
        x_old = x
    return x

@nb.jit(nopython=True)
def fixed_point_while_nb(x):
    error=1
    x_old = x
    while error>0.01:
        x = f(x)
        error = abs(x_old-x)
        x_old = x
    return x

print("for loop without numba:")
ipython.magic("%timeit fixed_point_for(10)")

print("for loop with numba:")
ipython.magic("%timeit fixed_point_for_nb(10)")

print("while loop without numba:")
ipython.magic("%timeit fixed_point_while(10)")

print("for loop with numba:")
ipython.magic("%timeit fixed_point_while_nb(10)")

因为我不知道你的... f 我只是用了我能想到的最简单的稳定功能。然后我跑了测试与和不 numba两次都是 forwhile 循环。我的机器上的结果是。

for loop without numba:
3.35 µs ± 8.72 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
for loop with numba:
282 ns ± 1.07 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
while loop without numba:
1.86 µs ± 7.09 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
for loop with numba:
214 ns ± 1.36 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

我的想法是:

  • 不可能是你的函数无法优化,因为你的... for 循环是快速的(至少你是这么说的,你有没有测试过不使用 numba?).
  • 可能是你的函数需要更多的循环来收敛,而不是你想象的那样。
  • 我们使用的软件版本不同。我的版本是:
    • numba 0.49.0
    • numpy 1.18.3
    • python 3.8.2
© www.soinside.com 2019 - 2024. All rights reserved.