
Python multiprocessing is slower than regular processing. How can I improve it?

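I have a script that combs through a dataset of nodes/points to remove those that overlap. The actual script is more complex, but for demonstration I pared it down to a simple overlap check that does nothing with the result.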


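I tried a few variations with locks, queues, and pools, adding jobs one at a time instead of in bulk. Some of the worst offenders were orders of magnitude slower. Eventually I got it as fast as I could.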


The overlap check algorithm that is sent to the individual processes:


def check_overlap(args):
    tolerance = args['tolerance']
    this_coords = args['this_coords']
    that_coords = args['that_coords']

    overlaps = False
    distance_x = this_coords[0] - that_coords[0]
    if distance_x <= tolerance:
        distance_x = pow(distance_x, 2)
        distance_y = this_coords[1] - that_coords[1]
        if distance_y <= tolerance:
            distance = pow(distance_x + pow(distance_y, 2), 0.5)
            if distance <= tolerance:
                overlaps = True

    return overlaps

The processing function:


def process_coords(coords, num_processors=1, tolerance=1):
    import multiprocessing as mp
    import time

    if num_processors > 1:
        pool = mp.Pool(num_processors)
        start = time.time()
        print("Start script w/ multiprocessing")

    else:
        num_processors = 0
        start = time.time()
        print("Start script w/ standard processing")

    total_overlap_count = 0

    # outer loop through nodes
    start_index = 0
    last_index = len(coords) - 1
    while start_index <= last_index:

        # nature of the original problem means we can process all pairs of a single node at once, but not multiple, so batch jobs by outer loop
        batch_jobs = []

        # inner loop against all pairs for this node
        start_index += 1
        count_overlapping = 0
        for i in range(start_index, last_index+1, 1):

            if num_processors:
                # add job
                batch_jobs.append({
                    'tolerance': tolerance,
                    'this_coords': coords[start_index],
                    'that_coords': coords[i]
                })

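Regardless, the non-multiprocessing version consistently runs in under 0.4 seconds, while the best I can get with multiprocessing is just under 3.0 seconds. I know the algorithm here may be too simple to really reap the benefits, but considering that the case above has nearly half a million iterations, and the real case has significantly more, it seems odd to me that multiprocessing is an order of magnitude slower.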


What am I missing, and what can I do to improve it?
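One rough way to see where the time might go is to compare the cost of a single pair check with the cost of pickling and unpickling a single job dict, which is part of what a pool has to do for every job it ships to a worker. The sketch below assumes the dict layout from the question; `measure` is a hypothetical helper, and the timings ignore pipe and scheduling overhead, so the numbers are indicative only.

```python
import pickle
import timeit


def check_overlap(args):
    # Per-pair check with the same result as the one in the question,
    # written without the early exits.
    tolerance = args['tolerance']
    this_coords = args['this_coords']
    that_coords = args['that_coords']
    dx = this_coords[0] - that_coords[0]
    dy = this_coords[1] - that_coords[1]
    return (dx * dx + dy * dy) ** 0.5 <= tolerance


def measure(job, repeats=10000):
    # Hypothetical helper: average seconds per call for the check itself
    # and for one pickle round trip of the job dict.
    check_time = timeit.timeit(
        lambda: check_overlap(job), number=repeats) / repeats
    pickle_time = timeit.timeit(
        lambda: pickle.loads(pickle.dumps(job)), number=repeats) / repeats
    return check_time, pickle_time


if __name__ == "__main__":
    job = {'tolerance': 1, 'this_coords': [0.0, 0.0], 'that_coords': [0.3, 0.4]}
    check_time, pickle_time = measure(job)
    print("check:  {0:.2e} s per call".format(check_time))
    print("pickle: {0:.2e} s per round trip".format(pickle_time))
```

If the pickle round trip costs about as much as the check or more, sending one dict per pair is unlikely to pay off, however many processes are used.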


杨魅力
1 answer

慕的地6264312

Building O(N**2) 3-element dicts that the serial code never uses, and transmitting them over interprocess pipes, is a pretty good way to guarantee that multiprocessing can't help ;-) Nothing comes for free - everything has a cost.

Below is a rewrite that executes much the same code regardless of whether it runs in serial or multiprocessing mode. No new dicts, etc. In general, the larger len(coords), the more benefit it gets from multiprocessing. On my machine, at 20000 the multiprocessing run takes about a third of the wall-clock time.

The key is that every process has its own copy of coords. That is done by transmitting it just once, when the pool is created. This should work on all platforms. On Linux-y systems, it could instead happen "by magic" via forked process inheritance. Cutting the amount of data sent across processes from O(N**2) to O(N) is a huge improvement.

Getting more out of multiprocessing would require better load balancing. As is, a call to check_overlap(i) compares coords[i] to each value in coords[i+1:]. The larger i, the less work there is for it to do, and for the largest values of i, just the cost of transmitting i between processes, and of transmitting the result back, swamps the time spent in check_overlap(i).

def init(*args):
    global _coords, _tolerance
    _coords, _tolerance = args

def check_overlap(start_index):
    coords, tolerance = _coords, _tolerance
    tsq = tolerance ** 2
    overlaps = 0
    start0, start1 = coords[start_index]
    for i in range(start_index + 1, len(coords)):
        that0, that1 = coords[i]
        dx = abs(that0 - start0)
        if dx <= tolerance:
            dy = abs(that1 - start1)
            if dy <= tolerance:
                if dx**2 + dy**2 <= tsq:
                    overlaps += 1
    return overlaps

def process_coords(coords, num_processors=1, tolerance=1):
    global _coords, _tolerance
    import multiprocessing as mp
    _coords, _tolerance = coords, tolerance
    import time

    if num_processors > 1:
        pool = mp.Pool(num_processors, initializer=init, initargs=(coords, tolerance))
        start = time.time()
        print("Start script w/ multiprocessing")
    else:
        num_processors = 0
        start = time.time()
        print("Start script w/ standard processing")

    N = len(coords)
    if num_processors:
        total_overlap_count = sum(pool.imap_unordered(check_overlap, range(N)))
    else:
        total_overlap_count = sum(check_overlap(i) for i in range(N))

    print(total_overlap_count)
    print("  time: {0}".format(time.time() - start))

if __name__ == "__main__":
    from random import random

    coords = []
    num_coords = 20000
    spread = 100.0
    half_spread = 0.5*spread
    for i in range(num_coords):
        coords.append([
            random()*spread-half_spread,
            random()*spread-half_spread
        ])

    process_coords(coords, 1)
    process_coords(coords, 4)
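The load-balancing remark above could be addressed in several ways. One possible sketch, not part of the answer, is to hand each task a strided set of rows so that every task mixes expensive rows (small i) with cheap ones (large i) and sends back a single number. The names `count_row`, `count_stride`, `count_overlaps` and the `tasks_per_process` knob are assumptions made for illustration.

```python
import multiprocessing as mp

_coords = None
_tolerance = None


def init(coords, tolerance):
    # Runs once in each worker: store the shared inputs as module globals.
    global _coords, _tolerance
    _coords, _tolerance = coords, tolerance


def count_row(i):
    # Count the points after index i that lie within tolerance of point i.
    x0, y0 = _coords[i]
    tsq = _tolerance ** 2
    overlaps = 0
    for j in range(i + 1, len(_coords)):
        x, y = _coords[j]
        if (x - x0) ** 2 + (y - y0) ** 2 <= tsq:
            overlaps += 1
    return overlaps


def count_stride(task):
    # One task covers rows offset, offset + stride, offset + 2 * stride, ...
    # so it mixes long rows (small i) with short rows (large i).
    offset, stride = task
    return sum(count_row(i) for i in range(offset, len(_coords), stride))


def count_overlaps(coords, tolerance=1, num_processors=4, tasks_per_process=4):
    # tasks_per_process is a hypothetical tuning knob, not from the answer.
    stride = num_processors * tasks_per_process
    tasks = [(offset, stride) for offset in range(stride)]
    with mp.Pool(num_processors, initializer=init,
                 initargs=(coords, tolerance)) as pool:
        return sum(pool.imap_unordered(count_stride, tasks))
```

Called as count_overlaps(coords, tolerance=1, num_processors=4) from under an `if __name__ == "__main__":` guard, this sends only a handful of small tuples to the workers and gets the same number of integers back. Whether it actually beats the version above depends on the data and the machine, so it is worth timing both.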