分布式Tensorflow:谁应用参数更新?

问题描述 投票:8回答:1

我已经使用过TensorFlow但是对于训练模型分发TensorFlow是新手。我的理解是,当前的最佳实践支持使用异步更新的数据并行模型:

Google Brain团队于2016年4月发布的一篇论文对各种方法进行了基准测试,发现使用一些备用复制品进行同步更新的数据并行性是最有效的,不仅收敛速度更快,而且还能产生更好的模型。 - Hands-On Machine Learning with Scikit-Learn and Tensorflow的第12章。

现在,我对进一步阅读这个架构的困惑是弄清楚哪个组件应用了参数更新:工作者还是参数服务器?

在下面的插图中,我很清楚工人计算梯度dJ / dw(损失J相对于参数权重w的梯度)。但谁应用梯度下降更新规则?

enter image description here

有点令人困惑的是,这个O'Reilly article on Distributed TensorFlow说明如下:

在更集中的架构中,设备以渐变的形式将其输出发送到参数服务器。这些服务器收集并聚合渐变。在同步训练中,参数服务器计算模型的最新版本,并将其发送回设备。在异步训练中,参数服务器将梯度发送到本地计算新模型的设备。在这两种体系结构中,循环重复直到训练终止。

上段建议在异步培训中:

  1. 工作人员计算渐变并将其发送到参数服务器。
  2. 参数服务器将渐变广播给工作人员。
  3. 每个工作人员接收广播的梯度并应用更新规则。

我的理解是否正确?如果是,那对我来说似乎并不是非同步的,因为工作人员必须等待参数服务器广播渐变。任何解释将不胜感激。

tensorflow machine-learning
1个回答
0
投票

通常,参数服务器仅存储全局参数,工作人员直接将其渐变应用于全局参数(存储在参数服务器上)。在异步训练中,不进行广播!工人在循环中执行以下操作:

  1. 从PS获取当前的全局参数
  2. 计算梯度
  3. 将渐变应用于全局参数(将渐变应用于存储在参数服务器上的变量后,tensorflow会将渐变发送到参数服务器并将其应用于那里)

在步骤1和3之间,全局参数会发生变化,因为其他工作人员会应用其渐变。梯度的应用通常是hogwild。

在异步训练中,参数服务器将梯度发送到设备

我认为在任何异步实现中都不会发生这种情况。不知道作者试图在这里说些什么。

© www.soinside.com 2019 - 2024. All rights reserved.