我正在尝试使用 dask 编写一个网格搜索实用程序。目标函数调用包含大量数据的类的方法。我正在尝试使用 dask 将计算并行化为多核解决方案,而无需复制原始类/数据帧。我在文档中没有找到任何解决方案,因此我在这里发布一个玩具示例:
import pickle
from dask.distributed import Client, LocalCluster
from multiprocessing import current_process
class TestClass:
def __init__(self):
self.param = 0
def __getstate__(self):
print("I am pickled!")
return self.__dict__
def loss(self, ext_param):
self.param += 1
print(f"{current_process().pid}: {hex(id(self))}: {self.param}: {ext_param} ")
return f"{self.param}_{ext_param}"
def objective_function(param):
return test_instance.loss(param)
if __name__ == '__main__':
test_instance = TestClass()
print(hex(id(test_instance)))
cluster = LocalCluster(n_workers=2)
client = Client(cluster)
futures = client.map(objective_function, range(20))
result = client.gather(futures)
print(result)
# ---- OUTPUT RESULTS ----
# 0x7fe0a5056d30
# I am pickled!
# I am pickled!
# 11347: 0x7fb9bcfa0588: 1: 0
# 11348: 0x7fb9bd0a2588: 1: 1
# 11347: 0x7fb9bcf94240: 1: 2
# 11348: 0x7fb9bd07b6a0: 1: 3
# 11347: 0x7fb9bcf945f8: 1: 4
# ['1_0', '1_1', '1_2', '1_3', '1_4']
我有以下问题:
为什么下面的 pickle 函数被调用两次?
我注意到 map 函数的每次迭代都使用 的新副本test_instance
,正如您可以从每次迭代的不同类地址以及属性test_instance.param
在每次迭代时设置为 0 的事实中看到的那样(此行为与我在这里强调的 multiprocessing.Pool 的标准实现不同)。我假设在每次迭代期间每个进程都会收到腌制类的新副本 - 这是正确的吗?
根据(2),test_instance
计算期间内存中有多少个 的副本?是 1 (对于主线程中的原始实例)+ 1 (腌制副本)+ 2 (每个进程中存在的实例)= 4 吗?有什么办法可以让这个值变成1吗?
我注意到,可以通过使用 Ray 库来获得一些共享内存解决方案,如本 github 问题中所建议的。
猛跑小猪
相关分类