Python:沿 3d 矩阵中的一个轴查找最大值,使非最大值为零

我有一个如下所示的 3-d 矩阵,并希望沿轴 1 取最大值,并将所有非最大值保持为零。


A = np.random.rand(3,3,2)


  [[[0.34444547, 0.50260393],

    [0.93374423, 0.39021899],

    [0.94485653, 0.9264881 ]],


   [[0.95446736, 0.335068  ],

    [0.35971558, 0.11732342],

    [0.72065402, 0.36436023]],


   [[0.56911013, 0.04456443],

    [0.17239996, 0.96278067],

    [0.26004909, 0.06767436]]]

想要的结果:


   [[0         , 0         ],

    [0         , 0         ],

    [0.94485653, 0.9264881]],


   [[0.95446736, 0          ],

    [0         , 0          ],

    [0         , 0.36436023]],


   [[0.56911013, 0         ],

    [0         , 0.96278067],

    [0         , 0         ]]])

我试过了:


B = np.zeros_like(A)  #return matrix of zero with same shape as A


max_idx = np.argmax(A, axis=1) #index along axis 1 with max value


    array([[2, 0],

           [2, 2],

           [0, 2],

           [0, 1]])


C = np.max(A, axis=1, keepdims = True) #gives a (4,1,2) matrix of max value along axis 1


    array([[[0.95377958, 0.92940525]],

           [[0.94485653, 0.9264881 ]],

           [[0.95446736, 0.36436023]],

           [[0.56911013, 0.96278067]]])

但我不知道如何将这些想法结合在一起以获得我想要的输出。请帮忙!!


跃然一笑
浏览 165回答 1
1回答

Cats萌萌

你可以得到3维指数的最大值从max_idx。中的值max_idx是最大值沿轴 1 的索引。有六个值,因为您的其他轴是 3 和 2 (3 x 2 = 6)。您只需要了解 numpy 通过它们获取其他每个轴的索引的顺序。您首先遍历最后一个轴:d0, d1, d2 = A.shapea0 = [i for i in range(d0) for _ in range(d2)]   # [0, 0, 1, 1, 2, 2]a1 = max_idx.flatten()                           # [2, 2, 0, 2, 0, 1]a2 = [k for _ in range(d0) for k in range(d2)]   # [0, 1, 0, 1, 0, 1]B[a0, a1, a2] = A[a0, a1, a2]输出:array([[[0.        , 0.        ],        [0.        , 0.        ],        [0.94485653, 0.9264881 ]],       [[0.95446736, 0.        ],        [0.        , 0.        ],        [0.        , 0.36436023]],       [[0.56911013, 0.        ],        [0.        , 0.96278067],        [0.        , 0.        ]]])
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python