Process gets blocked when using async all-reduce in torch.distributed

Problem description

I am trying to use the asynchronous all-reduce in torch.distributed that is introduced in the PyTorch docs. However, I found that even though I set async_op=True, the process is still blocked. Where did I go wrong?

I copied the example code provided in the docs and added some sleep and print calls to check whether it blocks.

import torch
import torch.distributed as dist
import os
import time

rank = int(os.getenv('RANK', '0'))
dist.init_process_group(
        backend='nccl',
        world_size=2,
        rank=rank,
        )

output = torch.tensor([rank]).cuda(rank)
if rank == 1:
    time.sleep(5)

s = torch.cuda.Stream()
print(f"Process {rank}: begin aysnc all-reduce", flush=True)
handle = dist.all_reduce(output, async_op=True)
# Wait ensures the operation is enqueued, but not necessarily complete.
handle.wait()
# Using result on non-default stream.
print(f"Process {rank}: async check")
with torch.cuda.stream(s):
    s.wait_stream(torch.cuda.default_stream())
    output.add_(100)
if rank == 0:
    # if the explicit call to wait_stream was omitted, the output below will be
    # non-deterministically 1 or 101, depending on whether the allreduce overwrote
    # the value after the add completed.
    print(output)

Output:

Process 0: begin async all-reduce
Process 1: begin async all-reduce
Process 1: async check
Process 0: async check
tensor([101], device='cuda:0')

I expected "Process 0: async check" to be printed before "Process 1: begin async all-reduce".

asynchronous pytorch distributed-computing
1 Answer

I guess the reason is that you call handle.wait() before print(f"Process {rank}: async check"). handle.wait() blocks the process until the all-reduce completes, which synchronizes rank 0 and rank 1. I think you can only expect "Process 0: async check" to be printed before "Process 1: begin async all-reduce" if you put handle.wait() after print(f"Process {rank}: async check").
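
For reference, below is a minimal sketch of the reordering the answer suggests. The only change relative to the question's code is that handle.wait() is moved after the rank-local print, so the process only blocks once the reduced value is actually needed. It assumes the same setup as the question: two GPUs and RANK, MASTER_ADDR, and MASTER_PORT provided by the launcher (e.g. torchrun).

import os
import time

import torch
import torch.distributed as dist

# Same setup as the question; RANK, MASTER_ADDR and MASTER_PORT are assumed
# to be set in the environment by the launcher.
rank = int(os.getenv('RANK', '0'))
dist.init_process_group(backend='nccl', world_size=2, rank=rank)

output = torch.tensor([rank]).cuda(rank)
if rank == 1:
    time.sleep(5)  # rank 1 reaches the collective late on purpose

print(f"Process {rank}: begin async all-reduce", flush=True)
handle = dist.all_reduce(output, async_op=True)

# Rank-local work that does not read `output` can run here while the
# collective may still be in flight.
print(f"Process {rank}: async check", flush=True)

# Block only once the reduced value is actually needed.
handle.wait()
output.add_(100)
print(f"Process {rank}: {output}", flush=True)

dist.destroy_process_group()

According to the answer, with this ordering rank 0 no longer synchronizes with rank 1 before printing "async check", which is the behavior the question expected.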
