[源码解析] PyTorch 分布式 Autograd (1) ---- 设计
0x00 摘要
本文以几篇PyTorch官方文档为基础来了解分布式 autograd 的设计和内部结构,在翻译时并没有逐字翻译,其中加入了自己的部分理解。分布式 autograd 后续文章的分析也会基于本文进行。
PyTorch分布式其他文章如下:
[源码解析]PyTorch如何实现前向传播(1) --- 基础类(上)
[源码解析]PyTorch如何实现前向传播(2) --- 基础类(下)
[源码解析] PyTorch如何实现前向传播(3) --- 具体实现
[源码解析] Pytorch 如何实现后向传播 (1)---- 调用引擎
[源码解析] Pytorch 如何实现后向传播 (2)---- 引擎静态结构
[源码解析] Pytorch 如何实现后向传播 (3)---- 引擎动态逻辑
[源码解析] PyTorch 如何实现后向传播 (4)---- 具体算法
[源码解析] PyTorch 分布式(1)------历史和概述
[源码解析] PyTorch 分布式(2) ----- DataParallel(上)
[源码解析] PyTorch 分布式(3) ----- DataParallel(下)
[源码解析] PyTorch 分布式(4)------分布式应用基础概念
[源码解析] PyTorch分布式(5) ------ DistributedDataParallel 总述&如何使用
[源码解析] PyTorch分布式(6) ---DistributedDataParallel -- 初始化&store
[源码解析] PyTorch 分布式(7) ----- DistributedDataParallel 之进程组
[源码解析] PyTorch 分布式(8) -------- DistributedDataParallel之论文篇
[源码解析] PyTorch 分布式(9) ----- DistributedDataParallel 之初始化
[源码解析] PyTorch 分布式(10)------DistributedDataParallel 之 Reducer静态架构
[源码解析] PyTorch 分布式(11) ----- DistributedDataParallel 之 构建Reducer和Join操作
[源码解析] PyTorch 分布式(12) ----- DistributedDataParallel 之 前向传播
[源码解析] PyTorch 分布式(13) ----- DistributedDataParallel 之 反向传播
0x01 分布式RPC框架
本文主要以 https://pytorch.org/docs/master/rpc/distributed_autograd.html 为基准,但是原文档要求用户熟悉 Autograd 机制和分布式 RPC 框架,因为我们已经分析过 Autograd 机制,所以我们先研究一下 分布式 RPC 框架。
1.1 RPC 框架
RPC(Remote Procedure Call)是一种设计或者技术思想,而不是协议或者规范。
对于 RPC 最简单的理解就是一个节点请求另外一个节点所提供的服务,但是对于用户代码来说需要维护一个"本地调用"的感觉,即,对于远程函数调用需要像调用本地的函数一样,远程服务或者代码看起来像运行在本地。
RPC 需要解决几个问题:
- 如何通讯:即如何在调用者和服务提供者之间建立连接。
- 如何寻址:即调用者如何找到服务提供者,怎么知道其中有什么服务。
- 如何发送参数:调用者发起远程调用时候,方法的参数需要通过 TCP 等协议传输到服务器,参数如何序列化?
- 如何接受参数:服务提供者收到参数之后如何反序列化,如何调用。
- 如何返回:服务提供者调用本地提供的服务之后,如何把返回值发送给调用者。
1.2 PyTorch RPC 四大支柱
以下翻译自官方文档 https://pytorch.org/docs/master/rpc.html。
分布式 RPC 框架通过一组原语提供了多机模型训练机制以允许远程通信,以及一个更高级别的 API 来自动区分拆分到多台机器上的模型。分布式 RPC 框架使远程运行函数变得容易,支持引用远程对象而无需复制真实数据,并提供 autograd 和优化器 API 以透明地向后运行和跨 RPC 边界更新参数。这些功能可以分为四组 API。
- **远程过程调用 (RPC) ** 支持使用给定的参数在指定的worker上运行函数并获取返回值或创建对返回值的引用。有三个主要的 RPC API:
rpc_sync()
(同步)、rpc_async()
(异步)和remote()
(异步并返回对远程返回值的引用)。如果用户代码在没有返回值的情况下无法继续,请使用同步 API。否则,使用异步 API 获取 Future,并在调用者需要返回值时等待 Future。remote()
API 在需要远程创建某些内容但从不需要将其获取给调用者时很有用。想象一下driver进程设置参数服务器和训练器的情况。Driver 可以在参数服务器上创建嵌入表,然后与训练器共享嵌入表的引用,但其本身永远不会在本地使用嵌入表。在这种情况下,rpc_sync()
和rpc_async()
已不再适用,因为他们总是意味着立即或在将来把返回值发给调用者。 - 远程引用 (RRef)用作指向本地或远程对象的分布式共享指针。它可以与其他 worker 共享,并且引用计数将被透明处理。每个 RRef 只有一个所有者,并且对象只存在于该所有者之中。持有 RRef 的非所有者worker 可以通过明确请求从所有者那里获取对象的副本。当 worker 需要访问某个数据对象,但它本身既不是对象的创建者
remote()
函数的调用者也不是对象的所有者时,这很有用。分布式优化器就是此类用例的一个示例。 - Distributed Autograd将所有参与前向传播 worker的本地 autograd 引擎缝合在一起,并在后向传播期间自动联系他们以计算梯度。在进行前向传递如果需要跨越多台机器时,这尤其有用,例如分布式模型并行训练、参数服务器训练等。 有了这个特性,用户代码不再需要担心如何跨 RPC 边界发送梯度和应该以什么顺序启动本地 autograd 引擎,如果前向传递中有嵌套和相互依赖的 RPC 调用,这可能会变得非常复杂。
- 分布优化器的构造需要一个
Optimizer()
(例如,SGD()
,Adagrad()
等)和一个RRefs的参数列表。即,在每个不同的Ref所有者之上创建一个Optimizer()
实例,然后运行step()
相应更新参数。当用户进行分布式前向和后向传播时,参数和梯度将分散在多个 worker 中,因此需要对每个相关 worker 进行优化。Distributed Optimizer 将所有这些本地优化器合而为一,并提供了简洁的构造函数和step()
API。
1.3 RRef
下面我们以 https://pytorch.org/docs/master/rpc/rref.html 为基准来学习远程引用协议的基本概念和部分设计细节。
RRef 是远程参考(Remote REFerence)的缩写。 它是位于本地或远程工作worker上对象的引用,并且透明地在内部进行引用计数。 从概念上讲,它可以被视为一个分布式共享指针。 应用程序可以调用 remote()
创建 一个RRef。 每个 RRef 都被 remote()
的调用者(即所有者)所拥有,并且可以由多个用户使用。 所有者存储实际数据,并跟踪全局参考计数。 每个 RRef 可以由全局RRefId
唯一标识,该全局RRefId
在创建时由 remote()
调用者分配。
在所有者worker中,只有一个OwnerRRef
实例包含真实数据,而在用户worker之中,可以根据需要包含任意数量的UserRRefs
,UserRRef
不保存数据。当使用 RRP 时,所有者将使用全局唯一的RRefId来获取唯一的OwnerRRef实例。 在 rpc_sync()
, rpc_async()
或 remote()
调用中,所有者创建一个UserRRef
,并将其用作参数或返回值。所有者将被通知并且相应更新参考计数。 如果全局没有UserRRef
实例,并且所有者上也没有对OwnerRRef
的引用,则OwnerRRef
及其数据将被删除。
1.3.1 假设条件
RRef 协议的设计基于以下假设。
- 瞬态网络故障(Transient Network Failures):RRef 设计旨在通过重试消息来应对瞬态网络故障。 RRef不能处理节点崩溃或永久性网络分区,当这些事件发生时,应用程序应该关闭所有worker,还原到先前的checkpoint,然后恢复训练。
- 非幂等 UDF (Non-idempotent UDFs):我们假设提供给
rpc_sync()
,rpc_async()
或remote()
的用户函数(UDF)不是幂等的,因此无法重试。 但是,内部 RRef 控制消息是幂等且消息失败时可重试。 - 消息传递无序(Out of Order Message Delivery):我们不会对一对节点之间的消息传递顺序做假设,因为发送者和接收者都使用多个线程,所以无法保证首先处理哪个消息。
接下来我们只是大致讲解如何使用,具体大家可以参阅 https://pytorch.org/docs/master/rpc.html#distributed-rpc-framework。
1.3.2 同步调用
如下是同步调用API,该方法在 worker to
之上执行一个阻塞 RPC 调用来运行func
。RPC 消息的发送和接收与 Python 代码的执行并行。此方法是线程安全的。
torch.distributed.rpc.rpc_sync( to , func , args = None , kwargs = None , timeout = - 1.0 )
具体参数如下:
- to – 目标worker的name/rank/WorkerInfo。
- func (callable) – 一个可调用函数,例如 Python callables、内置运算符(例如add())和带注释的 TorchScript 函数。
- args –
func
调用的参数元组。 - kwargs –
func
调用关键字参数的字典。 - timeout – 用于此 RPC 的超时时间(以秒为单位)
返回值就是使用args
and kwargs
运行 func
的结果。
样例:
确保 MASTER_ADDR
and MASTER_PORT
已经在两个worker之上设置。
export MASTER_ADDR=localhost
export MASTER_PORT=5678
然后在两个不同的进程中运行以下代码
>>> # On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
>>> rpc.shutdown()
>>> # On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
1.3.2 异步调用
如下是异步调用API,该方法在 worker to
之上执行一个非阻塞 RPC 调用来运行func
。RPC 消息的发送和接收与 Python 代码的执行并行。此方法是线程安全的。该方法立刻返回一个可以被等待的Future
。
torch.distributed.rpc.rpc_async(to, func, args=None, kwargs=None, timeout=- 1.0)
具体参数如下:
- to – 目标worker的name/rank/
WorkerInfo
。 - func (callable) – 一个可调用函数,例如 Python callables、内置运算符(例如add())和带注释的 TorchScript 函数。
- args –
func
调用的参数元组。 - kwargs – 是
func
调用关键字参数的字典。 - timeout – 用于此 RPC 的超时时间(以秒为单位)
返回一个可等待的Future
对象。完成后,可以从 对象中检索出func
的返回值。
样例:
确保 MASTER_ADDR
and MASTER_PORT
已经在两个worker之上设置。
>>> export MASTER_ADDR=localhost
>>> export MASTER_PORT=5678
然后在两个不同的进程中运行以下代码
>>> # On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))
>>> fut2 = rpc.rpc_async("worker1", min, args=(1, 2))
>>> result = fut1.wait() + fut2.wait()
>>> rpc.shutdown()
>>> # On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
0x02 示例
我们接下来以 https://pytorch.org/docs/master/rpc/distributed_autograd.html 为基础进行学习。
假设您有两个节点和一个跨两个节点分区的非常简单的模型。这可以使用torch.distributed.rpc
如下实现。
分布式 autograd 背后的主要动机是在这种分布式模型上运行反向传播loss
,我们已经计算并记录了所有需要梯度的张量的梯度。
import torch
import torch.distributed.rpc as rpc
def my_add(t1, t2):
return torch.add(t1, t2)
# On worker 0:
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
# Perform some computation remotely.
t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))
# Perform some computation locally based on remote result.
t4 = torch.rand((3, 3), requires_grad=True)
t5 = torch.mul(t3, t4)
# Compute some loss.
loss = t5.sum()
0x03 前向传播期间的 Autograd 记录
PyTorch 在前向传播期间构建 autograd 图,该图用于执行后向传播。有关更多详细信息,请参阅 autograd 如何编码历史记录。
对于分布式 autograd,我们需要在前向传播期间跟踪所有 RPC,以确保正确执行后向传播。为此,当执行 RPC 时候,我们把 send
和recv
functions 附加到autograd图之上。
- 该
send
函数附加到 RPC 的发起源节点之上,其输出边指向 RPC 输入张量的 autograd 函数。在向后传播期间,send
函数的输入是从目标接收的,是对应recv
函数的输出。 - 该
recv
函数附加到 RPC 的接受目标节点之上,其输入从某些运算符得到,这些运算符使用输入张量在RPC接受目标上执行。在后向传播期间,recv
函数的输出梯度将被发送到源节点之上,并且作为send
方法的输入。 - 每
send-recv
对被分配一个全局唯一的autograd_message_id
以唯一地标识该send-recv
对。这对于在向后传播期间查找远程节点上的相应函数很有用。 - 对于RRef,每当我们调用
torch.distributed.rpc.RRef.to_here()
时,我们都为涉及的张量添加了一个适当的send-recv
对。
例如,这就是我们上面示例的 autograd 图的样子(为简单起见,t5.sum() 被排除在外)。
我们可以看到,send方法在前向传播中是发送者,但是在反向传播之中就是接受者。
0x04 分布式 Autograd 上下文
每个使用分布式 autograd 的前向和后向传播都被分配了一个唯一的torch.distributed.autograd.context
,并且这个上下文具有一个全局唯一的autograd_context_id
。如果有需要,在每个节点上都会创建上下文。
上下文的作用如下:
- 运行分布式反向传播的多个节点可能会在同一个张量上累积梯度并且存储在张量的
.grad
之上。在我们运行优化器之前,张量的.grad
可能累积了来自各种分布式反向传播的梯度。这类似于把torch.autograd.backward()
在本地进行多次调用。为了提供一种把每个反向传播梯度分离开的方法,在每个反向传播过程里,梯度将被累积在torch.distributed.autograd.context
之中。 - 在前向传播期间,我们在上下文中存储每个 autograd 传播的
send
和recv
函数。这确保我们在 autograd 图中保存对适当节点的引用以使其保持活动状态。除此之外,这也使得在向后传播期间很容易查找到对应的send
和recv
函数。 - 一般来说,我们也使用这个上下文来存储每个分布式 autograd 传播的一些元数据。
从用户的角度来看,autograd 上下文设置如下:
import torch.distributed.autograd as dist_autograd
with dist_autograd.context() as context_id:
loss = model.forward()
dist_autograd.backward(context_id, loss)
需要注意的是,模型的前向传播必须在分布式autograd上下文管理器中调用,因为需要一个有效的上下文来确保:所有的send
和recv
方法被存储起来,并且在所有参与节点之上执行后向传播。
0x05 分布式反向传播
在本节中,我们将概述在分布式反向传播期间准确计算依赖关系所遇到的挑战,并且也讲述几种如何执行分布式反向传播的算法(算法内部有权衡)。
5.1 计算依赖关系
首先,考虑在单台机器上运行以下代码
import torch
a = torch.rand((3, 3), requires_grad=True)
b = torch.rand((3, 3), requires_grad=True)
c = torch.rand((3, 3), requires_grad=True)
d = a + b
e = b * c
d.sum.().backward()
下图就是上面代码对应的 autograd 图。
作为反向传播的一部分,autograd 引擎执行的第一步是计算 autograd 图中每个节点的依赖项数量。这有助于 autograd 引擎知道图中的节点何时准备好了可以执行。括号内为数字add(1)
和mul(0)
表示依赖关系的数量。如您所见,这意味着在向后传播期间,add
节点需要 1 个输入,mul
节点不需要任何输入(换句话说,不需要执行)。本地 autograd 引擎通过从根节点(在本例中是d
)遍历图来计算这些依赖关系。
实际上,Autograd 图中的某些节点可能不会在向后传播中执行。这一事实对分布式 autograd 提出了挑战。考虑这段使用 RPC 的代码。
import torch
import torch.distributed.rpc as rpc
a = torch.rand((3, 3), requires_grad=True)
b = torch.rand((3, 3), requires_grad=True)
c = torch.rand((3, 3), requires_grad=True)
d = rpc.rpc_sync("worker1", torch.add, args=(a, b))
e = rpc.rpc_sync("worker1", torch.mul, args=(b, c))
loss = d.sum()
上面代码的关联 autograd 图将是:
计算此分布式 autograd 图的依赖项更具挑战性,并且需要一些开销(在计算或网络通信方面)。
对于性能敏感的应用,我们可以通过假设每个send
和recv
函数都是反向传播的有效成分来避免大量开销(大多数应用不会执行未使用的 RPC)。这简化了分布式 autograd 算法并且效率更高,但代价是应用程序需要了解这些限制。这种算法称为FAST模式算法,下面详细介绍。
在一般情况下, 作为向后传播的一部分,可能不需要每个send
和recv
函数都是有效的。为了解决这个问题,我们提出了一种SMART 模式算法,此算法将在后面的部分中描述。请注意,目前仅实现了FAST模式算法。
5.2 FAST模式算法
该算法的关键假设是:当我们运行反向传播时,每个send
函数的依赖为 1。换句话说,我们假设我们会从另一个节点通过 RPC 接收梯度。
算法如下:
- 我们从具有反向传播根的worker开始(所有根都必须是本地的)。
- 查找当前Distributed Autograd Context 的所有
send
函数 。 - 从提供的根和我们检索到的所有
send
函数开始,我们在本地计算依赖项 。 - 计算依赖项后,使用提供的根来启动本地 autograd 引擎。
- 当 autograd 引擎执行该
recv
函数时,该recv
函数通过 RPC 将输入梯度发送到适当的worker。每个recv
函数都知道目标 worker id,因为它被记录为前向传播的一部分。通过autograd_context_id
和autograd_message_id
该recv
函数被发送到远程主机。 - 当远程主机收到这个请求时,我们使用
autograd_context_id
和autograd_message_id
来查找适当的send
函数。 - 如果这是worker第一次收到对给定
autograd_context_id
的请求,它将按照上面的第 1-3 点所述在本地计算依赖项。 - 然后将在第6点接受到的
send
方法插入队列,以便在该worker的本地 autograd 引擎上执行。 - 最后,我们不是在 Tensor的
.grad
之上累积梯度,而是在每个Distributed Autograd Context之上分别累积梯度 。梯度存储在Dict[Tensor, Tensor]
之中 ,Dict[Tensor, Tensor]
基本上是从 Tensor 到其关联梯度的映射,并且可以使用 get_gradients() API检索该映射 。
例如,分布式 autograd 的完整代码如下:
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
def my_add(t1, t2):
return torch.add(t1, t2)
# On worker 0:
# Setup the autograd context. Computations that take
# part in the distributed backward pass must be within
# the distributed autograd context manager.
with dist_autograd.context() as context_id:
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
# Perform some computation remotely.
t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))
# Perform some computation locally based on remote result.
t4 = torch.rand((3, 3), requires_grad=True)
t5 = torch.mul(t3, t4)
# Compute some loss.
loss = t5.sum()
# Run the backward pass.
dist_autograd.backward(context_id, [loss])
# Retrieve the gradients from the context.
dist_autograd.get_gradients(context_id)
具有依赖关系的分布式 autograd 图如下(为简单起见,t5.sum() 被排除在外):
应用于上述示例的FAST 模式算法如下:
- 在
Worker 0
上,我们从根loss
和send1
开始计算依赖关系。 结果,send1
对Worker 0
的依赖数为 1,mul
对Worker 0
的依赖数为 1。 - 现在,我们在
Worker 0
上启动本地 autograd 引擎。 我们首先执行mul
函数,将其输出作为t4
的梯度,累积存储在 autograd 上下文中。 然后,我们执行recv2
,它将这些梯度发送到Worker 1
。 - 由于这是
Worker 1
第一次知道有关此反向传播的信息,因此它将进行依赖关系计算,并且相应地标记send2
,add
和recv1
的依赖性。 - 接下来,在
Worker 1
的本地autograd
引擎上将send2
插入队列,该引擎将依次执行add
和recv1
。 - 当执行
recv1
时,它将梯度发送到Worker 0
。 - 由于
Worker 0
已经计算了此向后传播的依赖性,因此它仅仅在本地将send1
插入队列并且执行。 - 最后,
t1
,t2
和t4
的梯度会累积在分布式 Autograd 上下文中。
5.3 SMART模式算法
该算法的全部细节仍在研究中,但对于总体思路,您可以参考RFC中的分布式 Autograd 算法智能模式部分 。
0x06 分布式优化器
该DistributedOptimizer
操作如下:
- 获取要优化的远程参数(
RRef
)列表。这些参数也可以是包含在本地RRef
的本地参数。 - 将一个
Optimizer
类作为本地优化器,该优化器将在所有不同的RRef
拥有者之上运行。 - 分布式优化器在每个工作节点上创建一个本地
Optimizer
实例,并且对于每一个Optimizer
保存一个RRef
。 - 当调用
torch.distributed.optim.DistributedOptimizer.step()
时,分布式优化器使用 RPC 在适当的远程工作者上远程执行所有本地优化器。必须为torch.distributed.optim.DistributedOptimizer.step()
提供一个分布式autogradcontext_id
。 本地优化器使用context_id
在相应上下文中存储梯度。 - 如果多个并发分布式优化器正在更新一个 worker 上的同一批参数,这些更新将通过锁来进行序列操作。
0x07 简单的端到端示例
综上所述,以下是一个使用分布式 autograd 和分布式优化器的简单端到端示例。如果将代码放入名为“dist_autograd_simple.py”的文件中,则可以使用以下命令运行 :MASTER_ADDR="localhost" MASTER_PORT=29500 python dist_autograd_simple.py
import torch
import torch.multiprocessing as mp
import torch.distributed.autograd as dist_autograd
from torch.distributed import rpc
from torch import optim
from torch.distributed.optim import DistributedOptimizer
def random_tensor():
return torch.rand((3, 3), requires_grad=True)
def _run_process(rank, dst_rank, world_size):
name = "worker{}".format(rank)
dst_name = "worker{}".format(dst_rank)
# Initialize RPC.
rpc.init_rpc(
name=name,
rank=rank,
world_size=world_size
)
# Use a distributed autograd context.
with dist_autograd.context() as context_id:
# Forward pass (create references on remote nodes).
rref1 = rpc.remote(dst_name, random_tensor)
rref2 = rpc.remote(dst_name, random_tensor)
loss = rref1.to_here() + rref2.to_here()
# Backward pass (run distributed autograd).
dist_autograd.backward(context_id, [loss.sum()])
# Build DistributedOptimizer.
dist_optim = DistributedOptimizer(
optim.SGD,
[rref1, rref2],
lr=0.05,
)
# Run the distributed optimizer step.
dist_optim.step(context_id)
def run_process(rank, world_size):
dst_rank = (rank + 1) % world_size
_run_process(rank, dst_rank, world_size)
rpc.shutdown()
if __name__ == '__main__':
# Run world_size workers
world_size = 2
mp.spawn(run_process, args=(world_size,), nprocs=world_size)
0xFF 参考
https://pytorch.org/docs/master/rpc/distributed_autograd.html#distributed-autograd-design
https://pytorch.org/docs/master/rpc.html#distributed-autograd-framework
https://pytorch.org/docs/master/rpc/rref.html
https://pytorch.org/docs/master/rpc.html#distributed-rpc-framework