How to share a variable between threads in Joblib using an external module

Question (votes: 1, answers: 1)

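I am trying to modify the sklearn source code. Specifically, I am modifying the GridSearch source code so that the separate processes/threads that evaluate different model configurations share a variable. I need each thread/process to read and update that variable at run time, so that it can adapt its execution based on what the other threads have obtained. More specifically, the parameter I want to share is best, in the following snippet: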

out = parallel(
    delayed(_fit_and_score)(
        clone(base_estimator), X, y, best, self.early,
        train=train, test=test, parameters=parameters,
        **fit_and_score_kwargs)
    for parameters, (train, test)
    in product(candidate_params, cv.split(X, y, groups)))

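Note that the _fit_and_score function lives in a separate module. Sklearn uses joblib for parallelization, but I cannot figure out how to do this effectively with an external module. The joblib documentation provides the following code: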

>>> shared_set = set()
>>> def collect(x):
...    shared_set.add(x)
...
>>> Parallel(n_jobs=2, require='sharedmem')(
...     delayed(collect)(i) for i in range(5))
[None, None, None, None, None]
>>> sorted(shared_set)
[0, 1, 2, 3, 4]

But I cannot figure out how to make it work in my context. You can find the source code here:

scikit-learn joblib
1 answer
0 votes

You can use Python's Manager (https://docs.python.org/3/library/multiprocessing.html#multiprocessing.sharedctypes.multiprocessing.Manager). For example, a simple snippet:

from joblib import Parallel, delayed
from multiprocessing import Manager

manager = Manager()
q = manager.Namespace()
q.flag = False

def test(i, q):
    # Task 0 updates the shared variable
    if i == 0:
        q.flag = True

    # Poll the shared flag; return as soon as it has been set
    for n in range(100000000):
        if q.flag:
            return f'process {i} was updated'

    return f'process {i} was not updated'

out = Parallel(n_jobs=4)(delayed(test)(i, q) for i in range(4))

Output:

['process 0 was updated',
 'process 1 was updated',
 'process 2 was updated',
 'process 3 was updated']
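
If threads are acceptable, the require='sharedmem' approach from the joblib documentation quoted in the question can also work when the worker function lives in another module, provided the shared object is passed in as an argument instead of being read as a global. The following is a minimal sketch under that assumption, not the sklearn implementation: SharedBest, fit_and_score and run_search are hypothetical names, and the "score" is just the parameter value standing in for a real model fit.

```python
import threading

from joblib import Parallel, delayed


class SharedBest:
    """Hypothetical helper: holds the best score seen so far, guarded by a lock."""

    def __init__(self, initial=float("-inf")):
        self._lock = threading.Lock()
        self._value = initial

    def get(self):
        with self._lock:
            return self._value

    def update(self, score):
        # Compare and write under the lock so concurrent updates are not lost.
        with self._lock:
            if score > self._value:
                self._value = score
            return self._value


def fit_and_score(params, best):
    # Stand-in for a function defined in a separate module (assumption):
    # it only needs the shared object as an argument, not as a global.
    score = float(params)
    best.update(score)
    return score


def run_search(candidates, n_jobs=2):
    best = SharedBest()
    # require='sharedmem' makes joblib choose a thread-based backend, so every
    # task receives the same `best` object rather than a pickled copy.
    scores = Parallel(n_jobs=n_jobs, require="sharedmem")(
        delayed(fit_and_score)(p, best) for p in candidates
    )
    return scores, best.get()


if __name__ == "__main__":
    scores, best = run_search([3, 1, 4, 1, 5])
    print(scores, best)
```

The lock matters because "read the current best, compare, write" is not atomic. With process-based workers the lock and the object would not be shared this way, which is where a Manager proxy, as in the answer above, comes in.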