我正在寻找一种优雅的方法,根据指定要保留的维度的单个参数将任意形状的数组展平为矩阵。为了说明,我想
def my_func(input, dim):
# code to compute output
return output
例如给定一个inputshape 数组2x3x4,output应该是dim=0一个 shape 数组12x2;对于dim=1形状数组8x3;对于dim=2形状数组6x8。如果我只想压平最后一个维度,那么这很容易通过
input.reshape(-1, input.shape[-1])
但我想添加添加的功能dim(优雅地,不经过所有可能的情况+检查 if 条件等)。可以先交换维度,使感兴趣的维度在尾随,然后应用上述操作。
有什么帮助吗?
UYOU
相关分类