I am trying to use asynchronous all-reduce in torch.distributed, as described in the PyTorch docs. However, I found that even though I set async_op=True, the process still blocks. What am I doing wrong?
I copied the example code from the docs and added some sleep and print calls to check whether it blocks.
import torch
import torch.distributed as dist
import os
import time
rank = int(os.getenv('RANK', '0'))
dist.init_process_group(
    backend='nccl',
    world_size=2,
    rank=rank,
)
output = torch.tensor([rank]).cuda(rank)
if rank == 1:
    time.sleep(5)
s = torch.cuda.Stream()
print(f"Process {rank}: begin aysnc all-reduce", flush=True)
handle = dist.all_reduce(output, async_op=True)
# Wait ensures the operation is enqueued, but not necessarily complete.
handle.wait()
# Using result on non-default stream.
print(f"Process {rank}: async check")
with torch.cuda.stream(s):
    s.wait_stream(torch.cuda.default_stream())
    output.add_(100)
    if rank == 0:
        # if the explicit call to wait_stream was omitted, the output below will be
        # non-deterministically 1 or 101, depending on whether the allreduce overwrote
        # the value after the add completed.
        print(output)
Output:
Process 0: begin async all-reduce
Process 1: begin async all-reduce
Process 1: async check
Process 0: async check
tensor([101], device='cuda:0')
I expected "Process 0: async check" to be printed before "Process 1: begin async all-reduce".
I guess the reason is that print(f"Process {rank}: async check") is called after handle.wait(). handle.wait() blocks the process until the all-reduce completes, which synchronizes rank 0 and rank 1. Only if you move print(f"Process {rank}: async check") before handle.wait() can you expect "Process 0: async check" to be printed before "Process 1: begin async all-reduce".
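Here is a minimal sketch of that reordering, assuming the same two-rank NCCL setup as your script (RANK, MASTER_ADDR and MASTER_PORT set in the environment); the final result print is only added for illustration. The idea is that work placed between issuing the collective and calling handle.wait() is what can overlap with it:

import os
import time
import torch
import torch.distributed as dist

rank = int(os.getenv('RANK', '0'))
dist.init_process_group(backend='nccl', world_size=2, rank=rank)

output = torch.tensor([rank]).cuda(rank)
if rank == 1:
    time.sleep(5)

print(f"Process {rank}: begin async all-reduce", flush=True)
handle = dist.all_reduce(output, async_op=True)
# This print does not depend on the all-reduce result, so it is issued
# before waiting on the handle instead of after it.
print(f"Process {rank}: async check", flush=True)
# Wait on the handle before reading the result.
handle.wait()
print(f"Process {rank}: result {output.item()}", flush=True)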