使用 numba 优化 while 循环以实现容错

我在使用 numba 进行优化时有疑问。我正在编写一个定点迭代来计算一个名为 gamma 的数组的值,它满足方程 f(gamma)=gamma。我正在尝试使用 python 包 Numba 优化此功能。看起来如下。


@jit

def fixed_point(gamma_guess):

    for i in range(17):

        gamma_guess=f(gamma_guess)

    return gamma_guess


Numba 能够很好地优化这个功能,因为它知道它将执行多少次操作,17 次,并且运行速度很快。但是我需要控制我想要的伽玛的误差容限,我的意思是,一个伽玛和下一个通过定点迭代获得的伽玛之差应该小于某个数字epsilon = 0.01,然后我尝试了


@jit

def fixed_point(gamma_guess):

    err=1000

    gamma_old=gamma_guess.copy()

    while(error>0.01):

        gamma_guess=f(gamma_guess)

        err=np.max(abs(gamma_guess-gamma_old))

        gamma_old=gamma_guess.copy()

    return gamma_guess

它也可以工作并计算所需的结果,但不如上次实现快,它要慢得多。我认为这是因为 Numba 无法很好地优化 while 循环,因为我们不知道它何时会停止。有没有办法可以优化它并像上次实现一样快地运行?


编辑:


这是我正在使用的 f


from scipy import fftpack as sp

S=0.01

Amu=0.7

@jit 

def f(gammaa,z,zal,kappa):

    ka=sp.diff(kappa)

    gamma0=gammaa

    for i in range(N):

        suma=0

        for j in range(N):

            if (abs(j-i))%2 ==1:

                if((z[i]-z[j])==0):

                    suma+=(gamma0[j]/(z[i]-z[j]))   

        gamma0[i]=2.0*Amu*np.real(-(zal[i]/z[i])+zal[i]*(1.0/(2*np.pi*1j))*suma*2*h)+S*ka[i]

    return  gamma0

我总是用作初始猜测,np.ones(2048)*0.5传递给我的函数的其他参数是z=np.cos(alphas)+1j*(np.sin(alphas)+0.1)、zal=-np.sin(alphas)+1j*np.cos(alphas)和kappa=np.ones(2048)alphas=np.arange(0,2*np.pi,2*np.pi/2048)


紫衣仙女
浏览 128回答 1
1回答

函数式编程

我做了一个小测试脚本,看看我是否可以重现你的错误:import numba as nbfrom IPython import get_ipythonipython = get_ipython()@nb.jit(nopython=True)def f(x):    return (x+1)/xdef fixed_point_for(x):    for _ in range(17):        x = f(x)    return x@nb.jit(nopython=True)def fixed_point_for_nb(x):    for _ in range(17):        x = f(x)    return xdef fixed_point_while(x):    error=1    x_old = x    while error>0.01:        x = f(x)        error = abs(x_old-x)        x_old = x    return x@nb.jit(nopython=True)def fixed_point_while_nb(x):    error=1    x_old = x    while error>0.01:        x = f(x)        error = abs(x_old-x)        x_old = x    return xprint("for loop without numba:")ipython.magic("%timeit fixed_point_for(10)")print("for loop with numba:")ipython.magic("%timeit fixed_point_for_nb(10)")print("while loop without numba:")ipython.magic("%timeit fixed_point_while(10)")print("for loop with numba:")ipython.magic("%timeit fixed_point_while_nb(10)")由于我不了解您的f情况,因此我只使用了我能想到的最简单的稳定功能。然后,我在使用和不numba使用循环的情况下运行测试。我机器上的结果是:forwhilefor loop without numba:3.35 µs ± 8.72 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)for loop with numba:282 ns ± 1.07 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)while loop without numba:1.86 µs ± 7.09 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)for loop with numba:214 ns ± 1.36 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)产生以下想法:不可能,你的函数是不可优化的,因为你的for循环很快(至少你是这么说的;你没有测试过numba吗?)。如您所想,您的函数可能需要更多的循环才能收敛我们使用不同的软件版本。我的版本是:numba  0.49.0numpy  1.18.3python 3.8.2
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python