我们可以使用 torch.multiprocessing.spawn 进行 wandb 扫描超参数调整吗?

问题描述 投票:0回答:1

我们可以将

torch.multiprocessing.spawn
wandb.sweep
一起使用吗(https://docs.wandb.ai/guides/sweeps)。

torch.multiprocessing.spawn(func, nprocs=world_size, join=True)

我尝试过,但出现错误并且找不到教程。

pytorch wandb pytorch-distributions
1个回答
0
投票

我正在使用 Python 的标准

multiprocessing
库来生成
wandb.sweep
的代理:

import multiprocessing
import wandb
    
def init():
    '''set up config and start sweep'''
    sweep_config = {
        'method': 'grid',
        'name': 'sweep',
        'metric': {
            'goal': 'minimize',
            'name': 'val_loss'
        },
        'parameters': {
            'net': {'values': ['unet', 'swinunetr']},
            'batch_size': {'values': [1, 16, 32, 64, 128]},
            'lr': {'values': [1e-1, 1e-2, 1e-3, 1e-4, 1e-5]},
            'patch_size': {'values': [32, 64, 128, 256]},
            'samples_per_slice': {'values': [1, 2, 4, 8, 16]}
        }
    }
    
    # count num of combinations for grid sweep
    sweep_count = 1
    for key in sweep_config['parameters']:
        sweep_count *= len(sweep_config['parameters'][key]['values'])
    
    sweep_id = wandb.sweep(sweep=sweep_config, project='proj', )
    
    return sweep_id, sweep_count
    
def main():
    '''training code goes here'''
    # connect agent to sweep and import configs/params
    wandb.init()
    config = {
        'net': wandb.config.net,
        'batch_size': wandb.config.batch_size,
        'lr': wandb.config.lr,
        'patch_size': wandb.config.patch_size,
        'samples_per_slice': wandb.config.samples_per_slice,
    }
    print(config)
    wandb.log(config)
    ### ADD YOUR TRAINING CODE HERE
    
def agent(sweep_id, sweep_count):
    '''run agent'''
    wandb.agent(sweep_id, function=main, count=sweep_count, project='proj')
    
if __name__ == '__main__':
    # initialize sweep
    sweep_id, sweep_count = init()
    # spawn and run 4 agents for the sweep
    procs = []
    for _ in range(4):
        p = multiprocessing.Process(target=agent, args=[sweep_id, sweep_count])
        p.start()
        procs.append(p)
    for p in procs:
        p.join()

这会将所有

wandb
日志转储到单个终端窗口中,我还没有弄清楚。

© www.soinside.com 2019 - 2024. All rights reserved.