如何检查numpy数组的所有元素是否在另一个numpy数组中

我有两个2D numpy数组,例如:

A = numpy.array([[1, 2, 4, 8], [16, 32, 32, 8], [64, 32, 16, 8]])

B = numpy.array([[1, 2], [32, 32]])

我想拥有所有行,从中A可以找到的任何行的所有元素B。在的行中有2个相同元素的地方B,from的行也A必须至少包含2个。以我的示例为例,我想实现以下目标:

A_filtered = [[1, 2, 4, 8], [16, 32, 32, 8]]

我可以控制值的表示形式,因此我选择了数字表示形式,其中二进制表示形式只有一个位置1(例如:0b000000010b00000010等)。这样,通过使用np.logical_or.reduce()函数,我可以轻松地检查所有类型的值是否都在行中,但是我无法检查连续一行中相同元素的数量是否大于或等于A。我真的希望我可以避免简单的for循环和数组的深拷贝,因为性能对我来说是非常重要的方面。

我如何以有效的方式在numpy中做到这一点?


更新:

这里的解决方案可能有效,但是我认为性能对我来说是一个很大的问题,它A可能真的很大(> 300000行),并且B可能是中等的(> 30):

[set(row).issuperset(hand) for row in A.tolist() for hand in B.tolist()]

更新2:

set()解决方案无法正常工作,因为会set()丢弃所有重复的值。


白板的微信
浏览 561回答 2
2回答

郎朗坤

我认为这应该工作:首先,对数据进行如下编码(假设您的二进制方案似乎暗示了“令牌”的数量有限):制作一个形状[n_rows,n_tokens],dtype int8,其中每个元素都计算标记的数量。以相同的方式对B进行编码,形状为[n_hands,n_tokens]这样就可以对输出进行单个矢量化的表达。matchs =(A [None,:,:]> = B [:, None,:])。all(axis = -1)。(确切地说,如何将此匹配数组映射到所需的输出格式作为练习的内容留给读者,因为问题使它在多个匹配中均未定义)。但是我们在这里谈论的是每个令牌大于10 MB的内存。即使有了32个令牌,这也不应该是不可想象的。但是在这种情况下,最好不要对n_tokens或n_hands或两者上的循环进行矢量化处理;对于小n,for循环很好,或者如果主体中有足够的工作要做,则循环开销微不足道。只要n_tokens和n_hands保持适中,我认为这将是最快的解决方案,如果它停留在纯python和numpy的领域内。
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python