蟒蛇中FFT的循环加速(使用“np.einsum”)

问题:我想加速我的python循环,其中包含很多产品和总结,但我也对任何其他解决方案持开放态度。np.einsum


我的函数采用形状为 (n,n,3) 的向量配置 S(我的情况:n=72),并对 N*N 点的相关函数进行傅里叶变换。相关函数定义为每个向量与其他向量的乘积。这乘以向量位置乘以 kx 和 ky 值的余弦函数。每个位置在最终求和得到k空间中的一个点:i,jp,m


def spin_spin(S,N):

    n= len(S)

    conf = np.reshape(S,(n**2,3))

    chi = np.zeros((N,N))

    kx = np.linspace(-5*np.pi/3,5*np.pi/3,N)

    ky = np.linspace(-3*np.pi/np.sqrt(3),3*np.pi/np.sqrt(3),N)


    x=np.reshape(triangular(n)[0],(n**2))

    y=np.reshape(triangular(n)[1],(n**2))

    for p in range(N):

        for m in range(N):

            for i in range(n**2):

                for j in range(n**2):        

                    chi[p,m] += 2/(n**2)*np.dot(conf[i],conf[j])*np.cos(kx[p]*(x[i]-x[j])+ ky[m]*(y[i]-y[j]))

    return(chi,kx,ky)

我的问题是我需要大约100 * 100个点,这些点由kx * ky表示,并且循环需要几个小时才能完成具有72 * 72个向量的格的工作。计算次数:72 * 72 * 72 * 72 * 100 * 100 我不能使用内置的FFT,因为我的三角形网格,所以我需要一些其他选项来降低这里的计算成本。numpy


我的想法:首先,我意识到将配置重新塑造为向量列表而不是矩阵可以降低计算成本。此外,我使用了numba包,这也降低了成本,但它仍然太慢了。我发现计算这些对象的一个好方法是函数。计算每个向量与每个向量的乘积是通过以下方式完成的:np.einsum


np.einsum('ij,kj -> ik',np.reshape(S,(72**2,3)),np.reshape(S,(72**2,3)))

棘手的部分是计算 .在这里,我想将产品与向量的位置(例如)的形状列表(100,1)连接起来。特别是我真的不知道如何用来实现x方向和y方向的距离。np.cosnp.shape(x)=(72**2,1)np.einsum


要重现代码(可能不需要这样):首先,您需要一个矢量配置。你可以简单地用它来做,或者你用随机向量作为例子:np.ones((72,72,3)


def spherical_to_cartesian(r, theta, phi):

    '''Convert spherical coordinates (physics convention) to cartesian coordinates'''

    sin_theta = np.sin(theta)

    x = r * sin_theta * np.cos(phi)

    y = r * sin_theta * np.sin(phi)

    z = r * np.cos(theta)


    return x, y, z # return a tuple


def random_directions(n, r):

    '''Return ``n`` 3-vectors in random directions with radius ``r``'''

    out = np.empty(shape=(n,3), dtype=np.float64)


    for i in range(n):

        # Pick directions randomly in solid angle

        phi = random.uniform(0, 2*np.pi)

        theta = np.arccos(random.uniform(-1, 1))

        # unpack a tuple

        x, y, z = spherical_to_cartesian(r, theta, phi)

        out[i] = x, y, z

宝慕林4294392
浏览 141回答 2
2回答

翻翻过去那场雪

优化的努姆巴实施代码中的主要问题是使用极小的数据重复调用外部 BLAS 函数。在此代码中,只计算一次它们更有意义,但是如果您必须在循环中执行此计算,请编写Numba实现。例np.dot优化的功能(暴力破解)import numpy as npimport numba as nb@nb.njit(fastmath=True,error_model="numpy",parallel=True)def spin_spin(S,N):    n= len(S)    conf = np.reshape(S,(n**2,3))    chi = np.zeros((N,N))    kx = np.linspace(-5*np.pi/3,5*np.pi/3,N).astype(np.float32)    ky = np.linspace(-3*np.pi/np.sqrt(3),3*np.pi/np.sqrt(3),N).astype(np.float32)    x=np.reshape(triangular(n)[0],(n**2)).astype(np.float32)    y=np.reshape(triangular(n)[1],(n**2)).astype(np.float32)    #precalc some values    fact=nb.float32(2/(n**2))    conf_dot=np.dot(conf,conf.T).astype(np.float32)    for p in nb.prange(N):        for m in range(N):            #accumulating on a scalar is often beneficial            acc=nb.float32(0)            for i in range(n**2):                for j in range(n**2):                            acc+= conf_dot[i,j]*np.cos(kx[p]*(x[i]-x[j])+ ky[m]*(y[i]-y[j]))            chi[p,m]=fact*acc    return(chi,kx,ky)优化功能(删除冗余计算)有很多冗余的计算。这是有关如何删除它们的示例。这也是一个以双精度进行计算的版本。@nb.njit()def precalc(S):    #There may not be all redundancies removed    n= len(S)    conf = np.reshape(S,(n**2,3))    conf_dot=np.dot(conf,conf.T)    x=np.reshape(triangular(n)[0],(n**2))    y=np.reshape(triangular(n)[1],(n**2))    x_s=set()    y_s=set()    for i in range(n**2):        for j in range(n**2):            x_s.add((x[i]-x[j]))            y_s.add((y[i]-y[j]))    x_arr=np.sort(np.array(list(x_s)))    y_arr=np.sort(np.array(list(y_s)))    conf_dot_sel=np.zeros((x_arr.shape[0],y_arr.shape[0]))    for i in range(n**2):        for j in range(n**2):            ii=np.searchsorted(x_arr,x[i]-x[j])            jj=np.searchsorted(y_arr,y[i]-y[j])            conf_dot_sel[ii,jj]+=conf_dot[i,j]    return x_arr,y_arr,conf_dot_sel@nb.njit(fastmath=True,error_model="numpy",parallel=True)def spin_spin_opt_2(S,N):    chi = np.empty((N,N))    n= len(S)    kx = np.linspace(-5*np.pi/3,5*np.pi/3,N)    ky = np.linspace(-3*np.pi/np.sqrt(3),3*np.pi/np.sqrt(3),N)    x_arr,y_arr,conf_dot_sel=precalc(S)    fact=2/(n**2)    for p in nb.prange(N):        for m in range(N):            acc=nb.float32(0)            for i in range(x_arr.shape[0]):                for j in range(y_arr.shape[0]):                            acc+= fact*conf_dot_sel[i,j]*np.cos(kx[p]*x_arr[i]+ ky[m]*y_arr[j])            chi[p,m]=acc    return(chi,kx,ky)@nb.njit()def precalc(S):    #There may not be all redundancies removed    n= len(S)    conf = np.reshape(S,(n**2,3))    conf_dot=np.dot(conf,conf.T)    x=np.reshape(triangular(n)[0],(n**2))    y=np.reshape(triangular(n)[1],(n**2))    x_s=set()    y_s=set()    for i in range(n**2):        for j in range(n**2):            x_s.add((x[i]-x[j]))            y_s.add((y[i]-y[j]))    x_arr=np.sort(np.array(list(x_s)))    y_arr=np.sort(np.array(list(y_s)))    conf_dot_sel=np.zeros((x_arr.shape[0],y_arr.shape[0]))    for i in range(n**2):        for j in range(n**2):            ii=np.searchsorted(x_arr,x[i]-x[j])            jj=np.searchsorted(y_arr,y[i]-y[j])            conf_dot_sel[ii,jj]+=conf_dot[i,j]    return x_arr,y_arr,conf_dot_sel@nb.njit(fastmath=True,error_model="numpy",parallel=True)def spin_spin_opt_2(S,N):    chi = np.empty((N,N))    n= len(S)    kx = np.linspace(-5*np.pi/3,5*np.pi/3,N)    ky = np.linspace(-3*np.pi/np.sqrt(3),3*np.pi/np.sqrt(3),N)    x_arr,y_arr,conf_dot_sel=precalc(S)    fact=2/(n**2)    for p in nb.prange(N):        for m in range(N):            acc=nb.float32(0)            for i in range(x_arr.shape[0]):                for j in range(y_arr.shape[0]):                            acc+= fact*conf_dot_sel[i,j]*np.cos(kx[p]*x_arr[i]+ ky[m]*y_arr[j])            chi[p,m]=acc    return(chi,kx,ky)计时#brute-force%timeit res=spin_spin(S,100)#48 s ± 671 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)#new version%timeit res_2=spin_spin_opt_2(S,100)#5.33 s ± 59.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)%timeit res_2=spin_spin_opt_2(S,1000)#1min 23s ± 2.43 s per loop (mean ± std. dev. of 7 runs, 1 loop each)编辑(安全监控系统检查)import numba as nbimport numpy as np@nb.njit(fastmath=True)def foo(n):    x   = np.empty(n*8, dtype=np.float64)    ret = np.empty_like(x)    for i in range(ret.size):            ret[i] += np.cos(x[i])    return retfoo(1000)if 'intel_svmlcc' in foo.inspect_llvm(foo.signatures[0]):    print("found")else:    print("not found")#found如果有阅读此链接。它应该可以在Linux和Windows上运行,但我还没有在macOS上测试过它。not found

胡子哥哥

这是加快速度的一种方法。我没有开始使用np.einsum,因为对你的循环进行一些调整就足够了。减慢代码速度的主要因素是同一事物的冗余重新计算。此处的嵌套循环是犯罪者:for p in range(N):        for m in range(N):            for i in range(n**2):                for j in range(n**2):                            chi[p,m] += 2/(n**2)*np.dot(conf[i],conf[j])*np.cos(kx[p]*(x[i]-x[j])+ ky[m]*(y[i]-y[j]))它包含大量冗余,多次重新计算向量运算。考虑 np.dot(...):此计算完全独立于点 kx 和 ky。但只有点 kx 和 ky 需要用 m 和 n 进行索引。因此,您可以对所有i和j运行一次点积,并保存结果,而不是为每个m,n重新计算(这将是10,000次!在类似的方法中,不需要在晶格中的每个点重新计算两者之间的向量差异。在每个点上,您都会计算每个矢量距离,而只需要计算一次矢量距离,然后只需将此结果乘以每个晶格点即可。因此,在修复了循环并使用索引(i,j)作为键的字典来存储所有值之后,您只需在i,j上的循环中查找相关值即可。这是我的代码:def spin_spin(S, N):    n = len(S)    conf = np.reshape(S,(n**2, 3))    chi = np.zeros((N, N))    kx = np.linspace(-5*np.pi/3, 5*np.pi/3, N)    ky = np.linspace(-3*np.pi/np.sqrt(3), 3*np.pi/np.sqrt(3), N)    # Minor point; no need to use triangular twice    x, y = triangular(n)    x, y = np.reshape(x,(n**2)), np.reshape(y,(n**2))    # Build a look-up for all the dot products to save calculating them many times    dot_prods = dict()    x_diffs, y_diffs = dict(), dict()    for i, j in itertools.product(range(n**2), range(n**2)):        dot_prods[(i, j)] = np.dot(conf[i], conf[j])        x_diffs[(i, j)], y_diffs[(i, j)] = x[i] - x[j], y[i] - y[j]        # Minor point; improve syntax by converting nested for loops to one line    for p, m in itertools.product(range(N), range(N)):        for i, j in itertools.product(range(n**2), range(n**2)):            # All vector operations are replaced by look ups to the dictionaries defined above            chi[p, m] += 2/(n**2)*dot_prods[(i, j)]*np.cos(kx[p]*(x_diffs[(i, j)]) + ky[m]*(y_diffs[(i, j)]))    return(chi, kx, ky)我现在正在一台像样的机器上使用您提供的尺寸运行此内容,并且i,j上的循环将在两分钟内完成。这只需要发生一次;那么它只是一个在m,n上的循环。每一个大约需要90秒,所以仍然有2-3个小时的运行时间。我欢迎任何关于如何优化cos计算以加快速度的建议!我击中了优化的唾手可得的果实,但为了给人一种速度感,i,j的循环需要2分钟,这样它运行的次数减少了9,999次!
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python