如何在Numpy中将数组展平为矩阵?

我正在寻找一种优雅的方法,根据指定要保留的维度的单个参数将任意形状的数组展平为矩阵。为了说明,我想


def my_func(input, dim):

    # code to compute output

    return output

例如给定一个inputshape 数组2x3x4,output应该是dim=0一个 shape 数组12x2;对于dim=1形状数组8x3;对于dim=2形状数组6x8。如果我只想压平最后一个维度,那么这很容易通过


input.reshape(-1, input.shape[-1])


但我想添加添加的功能dim(优雅地,不经过所有可能的情况+检查 if 条件等)。可以先交换维度,使感兴趣的维度在尾随,然后应用上述操作。


有什么帮助吗?


慕标5832272
浏览 279回答 1
1回答

UYOU

我们可以置换轴并重塑 -# a is input array; axis is input axis/dimnp.moveaxis(a,axis,-1).reshape(-1,a.shape[axis])从功能上讲,它基本上是将指定的轴推到后面,然后重塑保持该轴的长度以形成第二个轴并合并其余的轴以形成第一个轴。样品运行 -In [32]: a = np.random.rand(2,3,4)In [33]: axis = 0In [34]: np.moveaxis(a,axis,-1).reshape(-1,a.shape[axis]).shapeOut[34]: (12, 2)In [35]: axis = 1In [36]: np.moveaxis(a,axis,-1).reshape(-1,a.shape[axis]).shapeOut[36]: (8, 3)In [37]: axis = 2In [38]: np.moveaxis(a,axis,-1).reshape(-1,a.shape[axis]).shapeOut[38]: (6, 4)
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python