
Alternative to multiprocessing.map that doesn't store the function's return values

I'm using Pool.imap_unordered from multiprocessing to run a function concurrently, but my RAM usage keeps growing.


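The problem is as follows: I have millions of data combinations (created with itertools.product) that need to be passed to a function. The function computes a score with an SVM and then stores the score together with the current combination. It doesn't return any value; it only computes the score and stores it in a shared value. I don't need all the other combinations, only the best one.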
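For reference, a setup like the one described above, where workers write the best score into shared memory instead of returning it, could look roughly like the following minimal sketch. It assumes the shared objects are a multiprocessing.Value and a multiprocessing.Array handed to each worker through the Pool initializer; the helper names (_init_worker, score_and_store, find_best) are hypothetical, and sum(params) stands in for the SVM score.

```python
import itertools
import multiprocessing

# Hypothetical module-level slots, filled in once per worker by _init_worker.
_best_score = None
_best_params = None


def _init_worker(best_score, best_params):
    # Runs once in each worker process and keeps references to the shared objects.
    global _best_score, _best_params
    _best_score = best_score
    _best_params = best_params


def score_and_store(params):
    score = sum(params)  # stand-in for the SVM score
    # Hold the Value's lock so the compare-and-update is atomic across workers.
    with _best_score.get_lock():
        if score > _best_score.value:
            _best_score.value = score
            _best_params[:] = params
    # Nothing is returned, so the parent only receives None for each task.


def find_best(total_combinations=2, total_features=4, processes=2):
    best_score = multiprocessing.Value("d", float("-inf"))
    best_params = multiprocessing.Array("i", total_features)
    combos = itertools.product(range(total_combinations), repeat=total_features)
    with multiprocessing.Pool(processes, initializer=_init_worker,
                              initargs=(best_score, best_params)) as pool:
        # Drain the iterator; each yielded None is dropped immediately.
        for _ in pool.imap_unordered(score_and_store, combos, chunksize=64):
            pass
    return best_score.value, list(best_params)


if __name__ == "__main__":
    print(find_best())  # -> (4.0, [1, 1, 1, 1])
```

Because score_and_store returns None, each result the parent receives from imap_unordered is discarded as soon as the loop advances.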


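With imap_unordered, RAM usage keeps growing until the program crashes from running out of memory. I assume this happens because imap stores the function's results: the function doesn't return anything, but maybe imap still keeps a None or Null value for each call?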


Here is some example code:


from functools import partial
import itertools
import multiprocessing
import time


def svm(input_data, params):
    # Copy the data to avoid changing the original data
    # as input_data is a reference to a pandas dataframe
    # and I need to shift columns up and down.
    dataset = input_data.copy()

    # Use svm here to analyse data
    score = sum(dataset) + sum(params)  # simulate score of svm

    # Simulate a process that takes a bit of time
    time.sleep(0.5)

    return (score, params)


if __name__ == "__main__":
    # Without this, multiprocessing gives error
    multiprocessing.freeze_support()

    # Set the number of worker processes:
    # no argument uses one process per CPU core,
    # an int sets the number of processes
    pool = multiprocessing.Pool()

    # iterable settings
    total_combinations = 2
    total_features = 45

    # Keep track of best score
    best_score = -1000
    best_param = [0 for _ in range(total_features)]

    input_data = [x * x for x in range(10000)]

    # Create a partial function with the necessary args
    func = partial(svm, input_data)
    params = itertools.product(range(total_combinations), repeat=total_features)

    # Calculate scores concurrently and keep only the best combination
    for score, param in pool.imap_unordered(func, params):
        if score > best_score:
            best_score = score
            best_param = param

    # Wait for all the processes to terminate their tasks
    pool.close()
    pool.join()

    print(best_score)
    print(best_param)



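In this example you'll notice that RAM usage grows over time. It isn't much here, but if you leave it running for a day or so (by increasing the range of the iterable), it will reach several GB. As I said, I have millions of combinations.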


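How should I solve this? Is there an alternative to imap that doesn't store anything about the function at all? Should I create Process objects directly instead of using a Pool? Or could the cause be that I copy the dataset and the garbage collector doesn't clean it up afterwards?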
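On the question of using plain Processes instead of a Pool: one way to keep memory flat is to feed the workers through a bounded multiprocessing.Queue, so the producer blocks instead of running ahead of the workers, and to have each worker report only its own best result at the end. This is a minimal sketch, not the asker's code; worker, search, max_pending and the sum(params) score are hypothetical placeholders.

```python
import itertools
import multiprocessing

_SENTINEL = None  # tells a worker there is no more work


def worker(task_queue, result_queue):
    # Each worker keeps only its own running best; nothing else accumulates.
    best_score, best_params = float("-inf"), None
    for params in iter(task_queue.get, _SENTINEL):
        score = sum(params)  # stand-in for the SVM score
        if score > best_score:
            best_score, best_params = score, params
    result_queue.put((best_score, best_params))


def search(total_combinations=2, total_features=4, n_workers=2, max_pending=100):
    # maxsize bounds how many combinations can wait in the queue at once,
    # so put() blocks instead of materialising the whole product in memory.
    task_queue = multiprocessing.Queue(maxsize=max_pending)
    result_queue = multiprocessing.Queue()
    procs = [multiprocessing.Process(target=worker, args=(task_queue, result_queue))
             for _ in range(n_workers)]
    for p in procs:
        p.start()
    for params in itertools.product(range(total_combinations), repeat=total_features):
        task_queue.put(params)
    for _ in procs:
        task_queue.put(_SENTINEL)
    # Read one result per worker before joining, so no worker blocks on a full pipe.
    results = [result_queue.get() for _ in procs]
    for p in procs:
        p.join()
    # A worker that received no tasks reports -inf, so max() ignores it.
    return max(results, key=lambda r: r[0])


if __name__ == "__main__":
    print(search())  # -> (4, (1, 1, 1, 1))
```

Sending one combination per put() keeps the sketch simple; with millions of cheap items, putting lists of combinations instead would reduce queue overhead.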


Asked by 暮色呼如 · 112 views · 2 answers

Answer by 牧羊人nacy

You can use Pool.apply or Pool.apply_async.
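One caveat with apply_async: submitting every combination in one go would queue all of them (plus an AsyncResult object each) in the parent process. The sketch below, which assumes a callback-based "keep the best" approach and uses the hypothetical helpers score, batched and best_with_apply_async, submits the work in fixed-size batches and lets a callback keep only the best result.

```python
import itertools
import multiprocessing


def score(params):
    return sum(params), params  # stand-in for svm(input_data, params)


def batched(iterable, size):
    # Yields lists of at most `size` items without materialising the iterable.
    it = iter(iterable)
    while batch := list(itertools.islice(it, size)):
        yield batch


def best_with_apply_async(total_combinations=2, total_features=4, batch_size=8):
    best = [float("-inf"), None]

    def keep_best(result):
        # Callbacks run in the parent's single result-handler thread, one at a time.
        s, p = result
        if s > best[0]:
            best[0], best[1] = s, p

    combos = itertools.product(range(total_combinations), repeat=total_features)
    with multiprocessing.Pool(2) as pool:
        for batch in batched(combos, batch_size):
            pending = [pool.apply_async(score, (p,), callback=keep_best) for p in batch]
            # Wait for this batch before submitting the next one, so at most
            # batch_size tasks and AsyncResult objects exist at any time.
            for r in pending:
                r.wait()
    return best[0], best[1]


if __name__ == "__main__":
    print(best_with_apply_async())  # -> (4, (1, 1, 1, 1))
```

Waiting on each batch leaves workers briefly idle at batch boundaries; a larger batch_size reduces that at the cost of holding more tasks in memory.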

Answer by 慕容3067478

I tracked memory usage with objgraph (import objgraph) by printing objgraph.show_most_common_types(limit=20). I noticed that the number of tuples and lists kept growing for as long as the child processes were alive. To fix this, I set maxtasksperchild on the Pool, which forces each worker process to exit after a number of tasks and thereby frees its memory.

from functools import partial
import itertools
import multiprocessing
import random
import time

# Tracing memory leaks
import objgraph


def svm(input_data, params):
    # Copy the data to avoid changing the original data
    # as input_data is a reference to a pandas dataframe.
    dataset = input_data.copy()

    # Use svm here to analyse data
    score = sum(dataset) + sum(params)  # simulate score of svm

    # Simulate a process that takes a bit of time
    time.sleep(0.5)

    return (score, params)


if __name__ == "__main__":
    # Without this, multiprocessing gives error
    # (freeze_support() should come first in the main block)
    multiprocessing.freeze_support()

    # iterable settings
    total_combinations = 2
    total_features = 12

    # Keep track of best score
    best_score = -1000
    best_param = [0 for _ in range(total_features)]

    # Simulate a dataframe with random data
    input_data = [random.random() for _ in range(100000)]

    # Create a partial function with the necessary args
    func = partial(svm, input_data)
    params = itertools.product(range(total_combinations), repeat=total_features)

    # Set the number of worker processes
    # Empty for all the cores
    # Int for number of processes
    with multiprocessing.Pool(maxtasksperchild=5) as pool:
        # Calculate scores concurrently
        # As the iterable is in the order of millions, this value
        # will get continuously large until it uses all available
        # memory as the map stores the results, that in this case
        # it's not needed.
        for score, param in pool.imap_unordered(func, iterable=params, chunksize=10):
            if score > best_score:
                best_score = score
                best_param = param
                # print(best_score)

            # Count the number of objects in the memory
            # If the number of objects keep increasing, it's a memory leak
            # (show_most_common_types prints the table itself and returns None)
            objgraph.show_most_common_types(limit=20)

        # Wait for all the processes to terminate their tasks
        # before the with block exits and terminates the pool
        pool.close()
        pool.join()

    print(best_score)
    print(best_param)
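The effect of maxtasksperchild can be checked in isolation. This minimal sketch (hypothetical helpers which_process and distinct_workers) counts how many distinct worker PIDs handle a batch of tasks: with maxtasksperchild=1 each task runs in a freshly started process, whereas without it the same processes are reused. Note that a chunk of chunksize items counts as a single task, so with the settings above (chunksize=10, maxtasksperchild=5) a worker would be replaced after about 50 combinations.

```python
import multiprocessing
import os


def which_process(_):
    # Report which worker process handled this item.
    return os.getpid()


def distinct_workers(maxtasksperchild, n_tasks=8, processes=2):
    with multiprocessing.Pool(processes, maxtasksperchild=maxtasksperchild) as pool:
        # chunksize=1 so each item counts as one task for maxtasksperchild.
        pids = pool.map(which_process, range(n_tasks), chunksize=1)
    return len(set(pids))


if __name__ == "__main__":
    print("reused workers:", distinct_workers(None))
    print("recycled workers:", distinct_workers(1))
```

Recycling workers bounds how much memory any one process can accumulate, at the cost of restarting processes (and, with a large input_data bound via partial, re-sending it to each new worker's tasks).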