Python 中“矩阵乘法”的更快定义

我需要从头开始定义矩阵乘法,而不是将每个常数相乘,每个常数实际上是另一个数组,任何两个数组都需要“卷积”在一起(我认为没有必要在这里定义卷积) .

我制作了一张图片,希望能更好地解释我想说的话:

http://img2.mukewang.com/63ec887800014ae103180316.jpg

我必须使用的代码是这样的:


for row in range(arr1.shape[2]):

    for column in range(arr2.shape[3]):

        for index in range(arr2.shape[2]): # Could also be "arr1.shape[3]"

            out[:, :, row, column] += convolve(

                arr2[:, :, :  , column][:, :, index],

                arr1[:, :, row, :     ][:, :, index]

            )

然而,这种方法对我来说非常慢,所以我想知道是否有更快的方法来做到这一点。


慕尼黑的夜晚无繁华
浏览 101回答 1
1回答

慕容森

如果中间适合内存,则以下内容应该相当有效import numpy as npfrom scipy.signal import fftconvolve,convolve# examplerng = np.random.default_rng()A = rng.random((5,6,2,3))                    B = rng.random((4,3,3,4))                    # custom matmulAe,Be = A[...,None],B[:,:,None]shsh = np.maximum(Ae.shape[2:],Be.shape[2:])Ae = np.broadcast_to(Ae,(*Ae.shape[:2],*shsh))Be = np.broadcast_to(Be,(*Be.shape[:2],*shsh))C = fftconvolve(Ae,Be,axes=(0,1),mode='valid').sum(3)         # original loop for referenceout = np.zeros_like(C)for row in range(A.shape[2]):    for column in range(B.shape[3]):        for index in range(B.shape[2]): # Could also be "A.shape[3]"            out[:, :, row, column] += convolve(                B[:, :, :  , column][:, :, index],                A[:, :, row, :     ][:, :, index],                mode='valid'            )print(np.allclose(C,out))# True通过批量进行卷积,我们减少了我们必须做的 fft 的总数。如果需要,可以通过使用 对傅里叶空间进行总和缩减,进一步优化速度和内存einsum。不过,这需要手动进行 fft 卷积。
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python