在 numpy 数组中按顺序查找相同整数的更快方法

现在我只是循环使用np.nditer()和比较前一个元素。是否有更快的(矢量化)方法?


额外的好处是我不必总是走到数组的末尾。一旦max_len找到一个序列,我就完成了搜索。


import numpy as np


max_len = 3

streak = 0

prev = np.nan


a = np.array([0, 3, 4, 3, 0, 2, 2, 2, 0, 2, 1])


for c in np.nditer(a):

  if c == prev:

      streak += 1

      if streak == max_len:

          print(c)

          break

  else:

      prev = c

      streak = 1

我想到的替代方案是使用np.diff(),但这只是转移了问题;我们现在正在其结果中寻找一系列零。我也怀疑它会更快,因为它必须计算每个整数的差异,而实际上,序列将在到达列表末尾之前发生。


互换的青春
浏览 107回答 4
4回答

至尊宝的传说

我开发了一个numpy可以工作的 -only 版本,但是经过测试,我发现它的性能很差,因为它不能利用short-circuiting。既然这是你要求的,我在下面描述它。但是,使用稍微修改过的代码版本有更好的方法。numba(请注意,所有这些都返回第一个匹配项的索引a,而不是值本身。我发现这种方法更灵活。)@numba.jit(nopython=True)def find_reps_numba(a, max_len):    streak = 1    val = a[0]    for i in range(1, len(a)):        if a[i] == val:            streak += 1            if streak >= max_len:                return i - max_len + 1        else:            streak = 1            val = a[i]    return -1结果证明这比纯 Python 版本快 100 倍。该numpy版本使用滚动窗口技巧和argmax 技巧。但同样,这甚至比纯 Python 版本慢得多,大约 30 倍。def rolling_window(a, window):    a = numpy.ascontiguousarray(a)  # This approach requires a C-ordered array    shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)    strides = a.strides + (a.strides[-1],)    return numpy.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)def find_reps_numpy(a, max_len):    windows = rolling_window(a, max_len)    return (windows == windows[:, 0:1]).sum(axis=1).argmax()我针对第一个函数的非抖动版本测试了这两个版本。(我使用 Jupyter 的%%timeit功能进行测试。)a = numpy.random.randint(0, 100, 1000000)%%timeitfind_reps_numpy(a, 3)28.6 ms ± 553 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)%%timeitfind_reps_orig(a, 3)4.04 ms ± 40.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)%%timeitfind_reps_numba(a, 3)8.29 µs ± 89.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)请注意,这些数字可能会有很大差异,具体取决于a函数必须搜索的深度。为了更好地估计预期性能,我们可以每次都重新生成一组新的随机数,但是如果不将那一步包含在时间中,就很难做到这一点。因此,为了在这里进行比较,我将生成随机数组所需的时间包括在内,而无需运行其他任何东西:a = numpy.random.randint(0, 100, 1000000)9.91 ms ± 129 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)a = numpy.random.randint(0, 100, 1000000)find_reps_numpy(a, 3)38.2 ms ± 453 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)a = numpy.random.randint(0, 100, 1000000)find_reps_orig(a, 3)13.7 ms ± 404 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)a = numpy.random.randint(0, 100, 1000000)find_reps_numba(a, 3)9.87 ms ± 124 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)正如你所看到的,find_reps_numba它是如此之快,以至于运行时间的差异numpy.random.randint(0, 100, 1000000)要大得多——因此第一次和最后一次测试之间的加速是虚幻的。所以这个故事的最大寓意是numpy解决方案并不总是最好的。有时甚至纯 Python 也更快。在这些情况下,numbainnopython模式可能是迄今为止最好的选择。

UYOU

假设您正在寻找至少max_len连续出现多次的元素,这是一种基于 NumPy 的方式 -m = np.r_[True,a[:-1]!=a[1:],True]idx0 = np.flatnonzero(m)m2 = np.diff(idx0)>=max_lenout = None # None for no such streak found caseif m2.any():    out = a[idx0[m2.argmax()]]另一个与binary-dilation-from scipy.ndimage.morphology import binary_erosionm = np.r_[False,a[:-1]==a[1:]]m2 = binary_erosion(m, np.ones(max_len-1, dtype=bool))out = Noneif m2.any():    out = a[m2.argmax()]最后,为了完整起见,您还可以查看numba. 您现有的代码可以直接循环a,即for c in a:.

弑天下

您可以创建长度的子数组,max_length每次向右移动一个位置(如ngrams),并检查一个 sub_array 除以的总和max_length是否等于该子数组的第一个元素。如果这是真的,那么你已经找到了长度为整数的连续序列max_length。def get_conseq(array, max_length):    sub_arrays = zip(*[array[i:] for i in range(max_length)])    for e in sub_arrays:        if sum(e) / len(e) == e[0]:            print("Found : {}".format(e))            return e    print("Nothing found")    return []例如,这个[1,2,2,3,4,5]带有的数组max_length = 2将像这样“拆分”: [1,2] [2,2] [2,3] [3,4] [4,5]在第二个元素 上,[2,2]和为 4,除以max_length2,与该子组的第一个元素匹配,函数返回。break如果那是您喜欢做的事情,您可以这样做,而不是像我一样返回。您还可以添加一些规则来捕获边缘情况,使事情变得干净(空数组,max_length优于数组的长度等)。以下是一些示例调用:>>> splits([1,2,3,4,5,6], 2)Nothing found>>> splits([1,2,2,3,4,5,6], 3)Nothing found>>> splits([1,2,3,3,3], 3)Found : [3, 3, 3]>>> splits([1,2,2,3,3], 2)Found : [2, 2]希望这可以帮助 !

白衣染霜花

您可以groupby从itertools包装中使用。import numpy as npfrom itertools import groupbymax_len = 3best = ()a = np.array([0, 3, 4, 3, 0, 2, 2, 2, 0, 2, 1])for k, g in groupby(a):    tup_g = tuple(g)    if tup_g==max_len:        best = tup_g        break    if len(tup_g) > len(best):        best = tup_gbest# returns:(2, 2, 2)
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python