我想并行 torch.nn.ModuleList,例如
nets = torch.nn.ModuleList([net1, net2])
nets = torch.nn.parallel.DistributedDataParallel(nets)
但是当我使用 net1 进行下一步时
x = nets[0](img)
我得到一个错误:“DistributedDataParallel”对象不可订阅。
我试过了
x = nets.module[0](img)
它可以工作,但我不确定 DDP 是在工作还是只在第一个 GPU 上运行?
如果它只是在第一个GPU上运行,如何将前向步骤与net1并行
从这个问题来看,你打算并行化的内容不是很清楚。
nets.modules[0]
从内部列表中捕获对原始网络net1
的引用,因此绕过 DDP(以及绕过 net2
)。
(Distributed)DataParallel 通过拆分输入数据来跨设备拆分输入。每个设备复制一次(整个)模型。 ModuleList 是模块列表,通常不会传递给 DDP。
一些可能性:
net1(net2(input))
,那么您可能正在寻找nn.Sequential
而不是nn.ModuleList
。那么分布式模型的应用就是nets(img)
。传递给 DDP 的模块本身应该有一个可调用的 forward 方法。.to(device)
在层之间跨设备移动输入,或者如果在多个主机上运行则使用 PyTorch 的 RPC 机制。