在 NumPy ndArray 中基于布尔值查找最长序列的更有效解决方案

我搜索我的 ndArray 以查找基于 True 值的最长系列。是否可以选择在不遍历数组的情况下查找最长系列?


我已经用 numpy.nonzero 编写了自己的解决方案,但可能有更好的解决方案。


import numpy as np

arr = np.array([[[1,2,3,4,5],

                [6,7,8,9,10],

                [11,12,13,14,15],

                [16,17,18,19,20],

                [21,22,23,24,25]],

                [[True,True,True,False,True],

                [True,True,True,True,False],

                [True,True,False,True,True],

                [True,True,True,False,True],

                [True,True,True,False,True]]])


def getIndices(arr):

    arr_to_search = np.nonzero(arr)

    arrs = []

    prev_el0 = 0

    prev_el1 = -1

    activ_long = []

    for i in range(len(arr_to_search[0])):

        if arr_to_search[0][i] == prev_el0:

            if arr_to_search[1][i] != prev_el1 + 1:

                arrs.append(activ_long)

                activ_long = []

        else:

            arrs.append(activ_long)

            activ_long = []

        activ_long.append((arr_to_search[0][i],arr_to_search[1][i]))

        prev_el0 = arr_to_search[0][i]

        prev_el1 = arr_to_search[1][i]


    max_len = len(max(arrs,key=len))

    longest_arr_list = [a for a in arrs if len(a) == max_len]

    return longest_arr_list


print(getIndices(arr[1,:,:]))

print(getIndices(arr[1,:,:].T))



[[(1, 0), (1, 1), (1, 2), (1, 3)]]

[[(0, 0), (0, 1), (0, 2), (0, 3), (0, 4)], [(1, 0), (1, 1), (1, 2), (1, 3), (1, 4)]]


有只小跳蛙
浏览 213回答 1
1回答

叮当猫咪

这是一个 numpy 解决方案,它避免了基于上一个问题的显式循环。我假设布尔数组名为a. 本质上,我们找到行从 0 到 1 或从 1 到 0 变化的索引,并查看它们之间的差异。通过在前后填充 0,我们确保对于从 0 到 1 的每个转换,还有另一个从 1 到 0 的转换。为了方便我处理a,并a.T在同一时间,但你可以分开,如果你想要做他们。m,n = a.shapeA = np.zeros((2*m,n+2))A[:m,1:-1] = aA[m:,1:-1] = a.TdA = np.diff(A)start = np.where(dA>0)end = np.where(dA<0)argmax_run = np.argmax(end[1]-start[1])row = start[0][argmax_run]col_start = start[1][argmax_run]col_end= end[1][argmax_run]-1max_len = col_end - col_start + 1print('max run of length {}'.format(max_len))print('in '+('row' if row<m else'col')+' {}'.format(row%m)+' from '+('col' if row<m else'row')+' {} to {}'.format(col_start,col_end))为了提高性能和存储,我们可以更改A为布尔数组。由于-1和1在dA上面总是成对出现,我们可以找到start和end如下。nz = np.nonzero(dA)start = (nz[0][::2], nz[1][::2])end = (nz[0][1::2], nz[1][1::2])请注意,您可以然后完全删除变量start,end因为它们并不是真正需要的。
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python