If the intermediates fit in memory, the following should be fairly efficient:

```python
import numpy as np
from scipy.signal import fftconvolve, convolve

# example
rng = np.random.default_rng()
A = rng.random((5, 6, 2, 3))
B = rng.random((4, 3, 3, 4))

# custom matmul: add singleton axes so the last two dimensions broadcast
# like a matrix product, then convolve everything in one batched call
Ae, Be = A[..., None], B[:, :, None]
shsh = np.maximum(Ae.shape[2:], Be.shape[2:])
Ae = np.broadcast_to(Ae, (*Ae.shape[:2], *shsh))
Be = np.broadcast_to(Be, (*Be.shape[:2], *shsh))
C = fftconvolve(Ae, Be, axes=(0, 1), mode='valid').sum(3)

# original loop for reference
out = np.zeros_like(C)
for row in range(A.shape[2]):
    for column in range(B.shape[3]):
        for index in range(B.shape[2]):  # could also be "A.shape[3]"
            out[:, :, row, column] += convolve(
                B[:, :, :, column][:, :, index],
                A[:, :, row, :][:, :, index],
                mode='valid',
            )

print(np.allclose(C, out))
# True
```

By batching the convolutions we reduce the total number of FFTs that have to be computed. If needed, speed and memory can be optimized further by performing the sum reduction in Fourier space with einsum, although that requires doing the FFT convolution manually.
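A minimal sketch of what that manual variant could look like, reusing `A`, `B`, and `C` from above. The function name `fft_matmul_conv_valid` and the cropping logic are my own illustration under the assumption of real-valued inputs, not tested library code:

```python
import numpy as np

def fft_matmul_conv_valid(A, B):
    # full linear-convolution sizes along the two convolved axes
    s0 = A.shape[0] + B.shape[0] - 1
    s1 = A.shape[1] + B.shape[1] - 1
    # real FFTs over the first two axes only
    Fa = np.fft.rfftn(A, s=(s0, s1), axes=(0, 1))
    Fb = np.fft.rfftn(B, s=(s0, s1), axes=(0, 1))
    # matmul-style contraction directly on the spectra: the shared axis is
    # summed here, so there is no broadcasting, no .sum(3) afterwards,
    # and only a single inverse FFT
    Fc = np.einsum('xyik,xykj->xyij', Fa, Fb)
    full = np.fft.irfftn(Fc, s=(s0, s1), axes=(0, 1))
    # crop the 'valid' region, mirroring fftconvolve(..., mode='valid')
    i0 = min(A.shape[0], B.shape[0]) - 1
    i1 = min(A.shape[1], B.shape[1]) - 1
    n0 = abs(A.shape[0] - B.shape[0]) + 1
    n1 = abs(A.shape[1] - B.shape[1]) + 1
    return full[i0:i0 + n0, i1:i1 + n1]

print(np.allclose(fft_matmul_conv_valid(A, B), C))
# should print True (up to floating-point tolerance)
```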