在numpy中通过另一个数组过滤数组元素

这里有一个简单的例子


import numpy as np

x=np.random.rand(5,5)

k,p = np.where(x>0.5)

k 和 p 是索引数组


现在我有一个应该被视为 m=[0,2,4] 的行列表,所以我需要找到列表 m 中 k 的所有条目。


我想出了一个非常简单但效率低下的可怕解决方案


d = np.array([ (a,b) for a,b in zip(k,p) if a in m])

该解决方案有效,但速度很慢。我正在寻找一种更好、更有效的方法。我需要使用动态调整的 m 进行数百万次这样的操作,因此算法的效率确实是一个关键问题。


慕村9548890
浏览 280回答 3
3回答

holdtom

也许下面更快:d=np.dstack((k,p))[0]print(d[np.isin(d[:,0],m)])

噜噜哒

您可以使用isin()获取可用于索引的布尔掩码k。>>> x=np.random.rand(3,3)>>> xarray([[0.74043564, 0.48328081, 0.82396324],       [0.40693944, 0.24951958, 0.18043229],       [0.46623863, 0.53559775, 0.98956277]])>>> k, p = np.where(x > 0.5)>>> parray([0, 2, 1, 2])>>> karray([0, 0, 2, 2])>>> marray([0, 1])  >>> np.isin(k, m)array([ True,  True, False, False])>>> k[np.isin(k, m)]array([0, 0])

SMILET

怎么样:import numpy as npm = np.array([0, 2, 4])k, p = np.where(x[m, :] > 0.5)k = m[k]print(zip(k, p))这只考虑有趣的行(然后将它们压缩到 2d 索引)。
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python