我们可以将
torch.multiprocessing.spawn
与 wandb.sweep
一起使用吗(https://docs.wandb.ai/guides/sweeps)。
torch.multiprocessing.spawn(func, nprocs=world_size, join=True)
我尝试过,但出现错误并且找不到教程。
我正在使用 Python 的标准
multiprocessing
库来生成 wandb.sweep
的代理:
import multiprocessing
import wandb
def init():
'''set up config and start sweep'''
sweep_config = {
'method': 'grid',
'name': 'sweep',
'metric': {
'goal': 'minimize',
'name': 'val_loss'
},
'parameters': {
'net': {'values': ['unet', 'swinunetr']},
'batch_size': {'values': [1, 16, 32, 64, 128]},
'lr': {'values': [1e-1, 1e-2, 1e-3, 1e-4, 1e-5]},
'patch_size': {'values': [32, 64, 128, 256]},
'samples_per_slice': {'values': [1, 2, 4, 8, 16]}
}
}
# count num of combinations for grid sweep
sweep_count = 1
for key in sweep_config['parameters']:
sweep_count *= len(sweep_config['parameters'][key]['values'])
sweep_id = wandb.sweep(sweep=sweep_config, project='proj', )
return sweep_id, sweep_count
def main():
'''training code goes here'''
# connect agent to sweep and import configs/params
wandb.init()
config = {
'net': wandb.config.net,
'batch_size': wandb.config.batch_size,
'lr': wandb.config.lr,
'patch_size': wandb.config.patch_size,
'samples_per_slice': wandb.config.samples_per_slice,
}
print(config)
wandb.log(config)
### ADD YOUR TRAINING CODE HERE
def agent(sweep_id, sweep_count):
'''run agent'''
wandb.agent(sweep_id, function=main, count=sweep_count, project='proj')
if __name__ == '__main__':
# initialize sweep
sweep_id, sweep_count = init()
# spawn and run 4 agents for the sweep
procs = []
for _ in range(4):
p = multiprocessing.Process(target=agent, args=[sweep_id, sweep_count])
p.start()
procs.append(p)
for p in procs:
p.join()
这会将所有
wandb
日志转储到单个终端窗口中,我还没有弄清楚。