沿 numpy 数组中的轴获取最后一个元素

我需要一个函数来获取 numpy 数组中沿轴的最后一个元素。


例如,如果我有一个数组,


a = np.array([1, 2, 3])

该功能应该像


get_last_elements(a, axis=0)

>>> [3]

get_last_elements(a, axis=1)

>>> [1, 2, 3]

此函数也需要适用于多维数组:


b = np.array([[1, 2],

              [3, 4]])


get_last_elements(b, axis=0)

>>> [[2],

     [4]]

get_last_elements(b, axis=1)

>>> [3, 4]

有没有人有实现它的好主意?


噜噜哒
浏览 205回答 1
1回答

Helenr

您可以使用np.take它:def get_last_elements(a, axis=0):  shape = list(a.shape)  shape[axis] = 1  return np.take(a,-1,axis=axis).reshape(tuple(shape))输出:print(get_last_elements(b, axis=0))[[3 4]]print(get_last_elements(b, axis=1))[[2] [4]]
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python