pytorch DDP支持torch.nn.ModuleList吗?

问题描述 投票:0回答:1

我想并行 torch.nn.ModuleList,例如

nets = torch.nn.ModuleList([net1, net2])
nets = torch.nn.parallel.DistributedDataParallel(nets)

但是当我使用 net1 进行下一步时

x = nets[0](img)

我得到一个错误:“DistributedDataParallel”对象不可订阅。

我试过了

x = nets.module[0](img)

它可以工作,但我不确定 DDP 是在工作还是只在第一个 GPU 上运行?

如果它只是在第一个GPU上运行,如何将前向步骤与net1并行

python machine-learning deep-learning pytorch distributed-computing
1个回答
0
投票

从这个问题来看,你打算并行化的内容不是很清楚。

nets.modules[0]
从内部列表中捕获对原始网络
net1
的引用,因此绕过 DDP(以及绕过
net2
)。

(Distributed)DataParallel 通过拆分输入数据来跨设备拆分输入。每个设备复制一次(整个)模型。 ModuleList 是模块列表,通常不会传递给 DDP。

一些可能性:

  • 如果您的意图是在同一数据上训练两个完全不同的模型,那么最好分开进行,例如可能有两个训练循环。您不需要将输入数据分块来执行此操作。
  • 如果列表中的模块要在更大的模型中作为层缝合在一起,例如
    net1(net2(input))
    ,那么您可能正在寻找
    nn.Sequential
    而不是
    nn.ModuleList
    。那么分布式模型的应用就是
    nets(img)
    。传递给 DDP 的模块本身应该有一个可调用的 forward 方法。
  • 如果您打算在单个模型中跨模块/层并行化,您实际上可能正在寻找“模型并行”而不是“数据并行”。有一些模型并行教程使用
    .to(device)
    在层之间跨设备移动输入,或者如果在多个主机上运行则使用 PyTorch 的 RPC 机制。
© www.soinside.com 2019 - 2024. All rights reserved.