猿问

将 2d 数组与 3d 数组的每个切片相乘 - Numpy

我正在寻找一种优化的方法来计算 2d 数组与 3d 数组的每个切片的元素乘法(使用 numpy)。

例如:

w = np.array([[1,5], [4,9], [12,15]]) 

y = np.ones((3,2,3))

我想得到一个 3d 数组的结果,其形状与y.

不允许使用 * 运算符进行广播。就我而言,第三维很长,for 循环不方便。


BIG阳
浏览 186回答 1
1回答

慕后森

给定数组import numpy as npw = np.array([[1,5], [4,9], [12,15]])print(w)[[ 1  5] [ 4  9] [12 15]]和y = np.ones((3,2,3))print(y)[[[ 1.  1.  1.]  [ 1.  1.  1.]] [[ 1.  1.  1.]  [ 1.  1.  1.]] [[ 1.  1.  1.]  [ 1.  1.  1.]]]我们可以直接对数组进行乘法运算,z = ( y.transpose() * w.transpose() ).transpose()print(z)[[[  1.   1.   1.]  [  5.   5.   5.]] [[  4.   4.   4.]  [  9.   9.   9.]] [[ 12.  12.  12.]  [ 15.  15.  15.]]]我们可能会注意到,这会产生与 np.einsum('ij,ijk->ijk',w,y) 相同的结果,可能需要更少的努力和开销。
随时随地看视频慕课网APP

相关分类

Python
我要回答