dask Client.map() 调用期间会发生什么?

我正在尝试使用 dask 编写一个网格搜索实用程序。目标函数调用包含大量数据的类的方法。我正在尝试使用 dask 将计算并行化为多核解决方案,而无需复制原始类/数据帧。我在文档中没有找到任何解决方案,因此我在这里发布一个玩具示例:


import pickle

from dask.distributed import Client, LocalCluster

from multiprocessing import current_process



class TestClass:

    def __init__(self):

        self.param = 0


    def __getstate__(self):

        print("I am pickled!")

        return self.__dict__


    def loss(self, ext_param):

        self.param += 1

        print(f"{current_process().pid}: {hex(id(self))}:  {self.param}: {ext_param} ")

        return f"{self.param}_{ext_param}"



def objective_function(param):

    return test_instance.loss(param)


if __name__ == '__main__':


    test_instance = TestClass()

    print(hex(id(test_instance)))

    cluster = LocalCluster(n_workers=2)

    client = Client(cluster)

    futures = client.map(objective_function, range(20))

    result = client.gather(futures)

    print(result)

    

# ---- OUTPUT RESULTS ----

# 0x7fe0a5056d30

# I am pickled!

# I am pickled!

# 11347: 0x7fb9bcfa0588:  1: 0

# 11348: 0x7fb9bd0a2588:  1: 1

# 11347: 0x7fb9bcf94240:  1: 2

# 11348: 0x7fb9bd07b6a0:  1: 3

# 11347: 0x7fb9bcf945f8:  1: 4 

# ['1_0', '1_1', '1_2', '1_3', '1_4']

我有以下问题:

  1. 为什么下面的 pickle 函数被调用两次?

  2. 我注意到 map 函数的每次迭代都使用 的新副本test_instance,正如您可以从每次迭代的不同类地址以及属性test_instance.param在每次迭代时设置为 0 的事实中看到的那样(此行为与我在这里强调的 multiprocessing.Pool 的标准实现不同)。我假设在每次迭代期间每个进程都会收到腌制类的新副本 - 这是正确的吗?

  3. 根据(2),test_instance计算期间内存中有多少个 的副本?是 1 (对于主线程中的原始实例)+ 1 (腌制副本)+ 2 (每个进程中存在的实例)= 4 吗?有什么办法可以让这个值变成1吗?

我注意到,可以通过使用 Ray 库来获得一些共享内存解决方案,如本 github 问题中所建议的。


幕布斯7119047
浏览 90回答 1
1回答

猛跑小猪

为什么下面的 pickle 函数被调用两次?通常,python 的 pickle 有效地将实例变量和对导入模块中的类的引用捆绑在一起。在 中__main__,这可能不可靠,dask 回退到 cloudpickle(内部也调用 pickle)。在我看来,在第一次尝试腌制之前可能会进行"__main__"检查。distributed.protocol.pickle.dumps在每次迭代期间,每个进程都会收到 pickled 类的新副本是的。每次 dask 运行任务时,它都会反序列化输入,创建实例的 nw 副本。请注意,您的 dask 工作线程可能是通过 fork_server 技术创建的,因此内存不是简单地复制(这是执行操作的安全方法)。您可以在计算之前将实例“分散”给工作人员,他们可以重用其本地副本,但 dask 任务不应该通过改变对象来工作,而是通过返回结果(即功能上)来工作。内存中有多少个 test_instance 副本客户端中 1 个,加上每个正在执行的任务 1 个。序列化版本也可能存在,可能是保存在图中的版本,暂时保存在客户端,然后保存在调度程序上;在反序列化时它也会暂时存在于工作内存中。对于某些类型,零拷贝解/序列化是可能的。如果由于对象的大小而导致任务非常大,那么您绝对应该事先“分散”它们(client.scatter)。有什么办法可以让这个值变成1吗?您可以在进程中运行调度程序和/或工作线程来共享内存,但是,当然,您会失去与 GIL 的并行性。
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python