    [Source Code Analysis] PyTorch Distributed Elastic Training (5): The Rendezvous Engine

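    0x00 Abstract

    Earlier articles in this series covered the basic modules of PyTorch distributed training and walked through several official examples. We now turn to PyTorch elastic training. This is the fifth article on the topic; it looks at the internal engine of Rendezvous: how it handles nodes joining, nodes leaving, waiting, heartbeats, and so on.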

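    The articles in the elastic training series are:

    [Source Code Analysis] PyTorch Distributed Elastic Training (1): Overall Approach

    [Source Code Analysis] PyTorch Distributed Elastic Training (2): Launch and Single-Node Flow

    [Source Code Analysis] PyTorch Distributed Elastic Training (3): Agent

    [Source Code Analysis] PyTorch Distributed Elastic Training (4): Rendezvous Architecture and Logic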

    0x01 Introduction

    1.1 Overall system

    Elastic training can be understood as a runtime system built on top of Rendezvous.

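    • Agent focuses on the logic that runs on a specific node.

      • Agent is responsible for operations tied to concrete business logic, such as launching processes to run the user program, monitoring how the user program runs, and notifying Rendezvous when something goes wrong.
      • Agent is a worker manager: it starts and manages worker processes, assembles them into a worker group, monitors worker status, catches failed workers, and restarts the worker group when a worker fails or a new worker joins.
      • Agent maintains WORLD_SIZE and RANK. Users no longer need to supply them by hand; the Agent handles them automatically.
      • Agent is a background process on a specific node and is an independent entity. An Agent alone cannot deliver elastic training across the cluster, so a mechanism is needed for workers to discover each other and synchronize changes (WORLD_SIZE and RANK can themselves only be determined once multiple nodes have synchronized). That mechanism is Rendezvous, described next.
    • Rendezvous is responsible for cluster-level logic: it ensures that nodes reach a strongly consistent consensus on "which nodes are taking part in training".

      • Each Agent contains a Rendezvous handler. Together these handlers form a Rendezvous cluster, which in turn forms the Agent cluster.
      • Once Rendezvous completes, a shared key-value store is created; it implements the torch.distributed.Store API. The store is shared only by the members that completed the Rendezvous, and it is intended for Torch Distributed Elastic to exchange control and data information while the job initializes.
      • Rendezvous maintains all information about the current group on every agent. Each agent hosts a rendezvous; they communicate with each other and jointly maintain one set of information, which is kept in the Store mentioned above.
      • Rendezvous handles cluster-related logic such as adding nodes, removing nodes, and assigning ranks.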

    1.2 Rendezvous

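    So far, the Rendezvous picture looks as follows. DynamicRendezvousHandler holds the dynamic logic, and _RendezvousStateHolder stores state and other metadata (a static structure). The diagram also contains a _RendezvousOpExecutor that has not been introduced yet: it is the runtime engine, and this article looks at how _RendezvousOpExecutor works.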

    +-----------------------------+      +------------------------------------------------+
    | LocalElasticAgent           |      | WorkerSpec                                     |
    |                             |      |                                                |
    | +------------------------+  |      |   rdzv_handler = {DynamicRendezvousHandler} -------+
    | |WorkerGroup             |  |      |                                                |   |
    | |            spec +--------------> |   entry = worker_fn                            |   |
    | |            workers     |  |      |                                                |   |
    | |            store       |  |      |   role = {str} 'trainer'                       |   |
    | |            group_rank  |  |      |                                                |   |
    | |       group_world_size |  |      +------------------------------------------------+   |
    | |                        |  |                                                           |
    | +------------------------+  |                                                           |
    |                             |                                                           |
    | rdzv_run_id                 |                                                           |
    | store                       |            +-----------------------------------------+    |
    |                             |            |DynamicRendezvousHandler                 |    |
    +-----------------------------+            |                                         |    |
                                               |                                         |    |
                                               |   _settings: RendezvousSettings         | <--+
                                               |                                         |
                                               |   _store: Store                         |
                                               |                                         |
                                               |   _state_holder: _RendezvousStateHolder |
                                               |                                         |
                                               |   _op_executor: _RendezvousOpExecutor   |
                                               |                                         |
                                               +-----------------------------------------+
    

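    1.3 Decoupling

    _RendezvousOpExecutor splits the functionality into decoupled parts:

    • Business logic is abstracted into a set of operators, such as _RendezvousJoinOp.
    • Rendezvous internally maintains a state machine built from business functions; for example, the function _add_to_participants adds a participant.
    • The _RendezvousOpExecutor engine runs the operators. From an operator's result it obtains an Action, and it then uses that Action to call the matching business function.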
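
The operator / Action / business-function split described above can be illustrated with a deliberately tiny sketch. Every name in it (Action, State, JoinOp, ExitOp, Executor) is a hypothetical stand-in invented for this example, not the PyTorch class of similar name, and the state is a plain in-process set instead of a shared store.

```python
from enum import Enum
from typing import Callable, Set


class Action(Enum):  # hypothetical, far smaller than the real _Action enum
    ADD_TO_PARTICIPANTS = 1
    REMOVE_FROM_PARTICIPANTS = 2
    FINISH = 3


class State:
    """Toy stand-in for the rendezvous state."""

    def __init__(self) -> None:
        self.participants: Set[str] = set()


class JoinOp:
    """Operator: inspects the state and only decides what should happen next."""

    def __call__(self, node: str, state: State) -> Action:
        if node in state.participants:
            return Action.FINISH
        return Action.ADD_TO_PARTICIPANTS


class ExitOp:
    """Operator: asks to be removed if the node is still a participant."""

    def __call__(self, node: str, state: State) -> Action:
        if node in state.participants:
            return Action.REMOVE_FROM_PARTICIPANTS
        return Action.FINISH


class Executor:
    """Engine: runs an operator in a loop and maps each Action to a business function."""

    def __init__(self, node: str, state: State) -> None:
        self._node = node
        self._state = state

    def run(self, op: Callable[[str, State], Action]) -> None:
        action = None
        while action != Action.FINISH:
            action = op(self._node, self._state)
            if action == Action.ADD_TO_PARTICIPANTS:
                self._state.participants.add(self._node)
            elif action == Action.REMOVE_FROM_PARTICIPANTS:
                self._state.participants.discard(self._node)


state = State()
executor = Executor("node-0", state)

executor.run(JoinOp())
joined = set(state.participants)   # snapshot after joining

executor.run(ExitOp())
left = set(state.participants)     # snapshot after exiting
```

The operators never mutate the state themselves; they only return an Action, and the executor owns every mutation. That is the property that lets operators and business functions evolve independently.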

    This article mainly covers the Rendezvous engine that corresponds to the C10d backend.

    0x02 Engine Implementation

    2.1 Base class

    _RendezvousOpExecutor is the base class of the engine; it only declares the abstract method run.

    class _RendezvousOpExecutor(ABC):
        """Executes rendezvous operations."""
    
        @abstractmethod
        def run(
            self, state_handler: Callable[[_RendezvousContext, float], _Action], deadline: float
        ) -> None:
            """Executes a rendezvous operation.
    
            An operation is run inside a state machine and is expected to transition
            the rendezvous from one state to another.
    
            Args:
                state_handler:
                    A callable that is expected to return the next state transition
                    action based on the current state of the rendezvous.
                deadline:
                    The time, in seconds, at which the operation will be considered
                    timed-out.
            """
    

    _RendezvousContext is used here. It bundles the various pieces of Rendezvous information and hands them to the operation engine; this is where _RendezvousState and RendezvousSettings come into play.

    class _RendezvousContext:
        """Holds the context of the rendezvous.
    
        Attributes:
            node:
                The node descriptor associated with the current rendezvous handler
                instance.
            state:
                The current state of the rendezvous.
            settings:
                The rendezvous settings.
        """
    
        node: _NodeDesc
        state: _RendezvousState
        settings: RendezvousSettings
    
        def __init__(
            self, node: _NodeDesc, state: _RendezvousState, settings: RendezvousSettings
        ) -> None:
            self.node = node
            self.state = state
            self.settings = settings
    

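    2.2 Distributed operation engine

    _DistributedRendezvousOpExecutor extends _RendezvousOpExecutor and is the actual executor in ElasticTorch. Much like a Looper, it dispatches messages, invokes business functions, and maintains state.

    2.2.1 Definition

    Compared with its base class, _DistributedRendezvousOpExecutor adds member variables such as node information, state, and settings.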

    class _DistributedRendezvousOpExecutor(_RendezvousOpExecutor):
        """Executes rendezvous operations using a shared state.
    
        Args:
            node:
                The node descriptor associated with the current rendezvous handler
                instance.
            state_holder:
                The ``RendezvousStateHolder`` to use to sync the rendezvous state
                with other nodes.
            settings:
                The rendezvous settings.
        """
    
        _node: _NodeDesc
        _state: _RendezvousState
        _state_holder: _RendezvousStateHolder
        _settings: RendezvousSettings
    
        def __init__(
            self,
            node: _NodeDesc,
            state_holder: _RendezvousStateHolder,
            settings: RendezvousSettings,
        ) -> None:
            self._node = node
            self._state_holder = state_holder
            self._settings = settings
    

    The structure is as follows:

    +---------------------------------------------------------------+
    | _DistributedRendezvousOpExecutor                              |
    |                                                               |
    |                     +------------------------+                |
    |        _state +---> | _RendezvousState       |                |
    |                     |                        |                |
    |                     |       participants     |                |
    |                     |       wait_list        |                |
    |                     |       last_heartbeats  |                |
    |                     |       deadline         |                |
    |                     +------------------------+                |
    |                                                               |
    |                     +-------------------------+               |
    |      _settings +--> | RendezvousSettings      |               |
    |                     |                         |               |
    |                     +-------------------------+               |
    |                                                               |
    |                     +--------------------------------------+  |
    | _state_holder +---> | _BackendRendezvousStateHolder        |  |
    |                     |                                      |  |
    |                     |        _backend: RendezvousBackend   |  |
    |                     |        _state: _RendezvousState      |  |
    |                     |        _settings: RendezvousSettings |  |
    |                     |                                      |  |
    |                     +--------------------------------------+  |
    |                     +--------------------------------------+  |
    |                     | _NodeDesc                            |  |
    |     _node +-------> |              fqdn: str               |  |
    |                     |              pid: int                |  |
    |                     |              local_id: int           |  |
    |                     |                                      |  |
    |                     +--------------------------------------+  |
    +---------------------------------------------------------------+
    

    2.2.2 Invocation

    A few examples show how the engine is invoked. In each case an operator is created first and then passed to the engine's run function.

    2.2.2.1 _RendezvousKeepAliveOp
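    def _keep_alive(self) -> None:
        self._heartbeat_lock.acquire()
        op = _RendezvousKeepAliveOp() # create the operator
        deadline = self._get_deadline(self._settings.timeout.heartbeat)
        # Invoke the engine. Error handling and the lock release are elided
        # here; the full listing appears in section 3.2.1.2.
        self._op_executor.run(op, deadline)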
    
    2.2.2.2 _RendezvousCloseOp
    def _close(self) -> None:
        op = _RendezvousCloseOp() # create the operator
        deadline = self._get_deadline(self._settings.timeout.close)
        self._op_executor.run(op, deadline) # invoke the engine
    
    2.2.2.3 _RendezvousJoinOp
    def next_rendezvous(self) -> Tuple[Store, int, int]:
        """See base class."""
    
        self._stop_heartbeats()
    
        # Delay the execution for a small random amount of time if this is our
        # first run. This will slightly skew the rendezvous attempts across the
        # nodes and reduce the load on the backend.
        if self._state_holder.state.round == 0:
            _delay(seconds=(0, 0.3))
    
        exit_op = _RendezvousExitOp() # create the exit operator
        join_op = _RendezvousJoinOp() # create the join operator
    
        deadline = self._get_deadline(self._settings.timeout.join)
    
        self._op_executor.run(exit_op, deadline) # first leave any previous round
        self._op_executor.run(join_op, deadline) # then join the new round
    
        self._start_heartbeats()
    
        rank, world_size = self._get_world()
        store = self._get_store()
    
        return store, rank, world_size
    

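    2.2.3 Functionality

    In _DistributedRendezvousOpExecutor, the run function implements the core logic: it performs different operations depending on the action type.

    2.2.3.1 Main loop

    The code of run is as follows: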

        def run(
            self, state_handler: Callable[[_RendezvousContext, float], _Action], deadline: float
        ) -> None:
            """See base class."""
            action = None
    
            while action != _Action.FINISH: # loop until a FINISH action is returned
                # Reads or writes the latest rendezvous state shared by all nodes in
                # the rendezvous. Note that our local changes might get overridden
                # by another node if that node synced its changes before us.

                # Important: this is where state is synchronized across all nodes.
                has_set = self._state_holder.sync() # the latest state lives in the rendezvous backend
    
                self._state = self._state_holder.state
    
                ctx = _RendezvousContext(self._node, self._state, self._settings)
    
                # Determine the next action to take based on the current state of
                # the rendezvous.
                action = state_handler(ctx, deadline) # decide the next action; state_handler is the operator
    
                if action == _Action.FINISH:
                    continue
    
                if action == _Action.ERROR_CLOSED:
                    raise RendezvousClosedError()
    
                if action == _Action.ERROR_TIMEOUT:
                    raise RendezvousTimeoutError()
    
                if action == _Action.SYNC:
                    # Delay the execution by one second to avoid overloading the
                    # backend if we are asked to poll for state changes.
                    _delay(seconds=1)
                else:
                    if action == _Action.KEEP_ALIVE:
                        self._keep_alive()
                    elif action == _Action.ADD_TO_PARTICIPANTS:
                        self._add_to_participants()
                    elif action == _Action.ADD_TO_WAIT_LIST:
                        self._add_to_wait_list()
                    elif action == _Action.REMOVE_FROM_PARTICIPANTS:
                        self._remove_from_participants()
                    elif action == _Action.REMOVE_FROM_WAIT_LIST:
                        self._remove_from_wait_list()
                    elif action == _Action.MARK_RENDEZVOUS_COMPLETE:
                        self._mark_rendezvous_complete()
                    elif action == _Action.MARK_RENDEZVOUS_CLOSED:
                        self._mark_rendezvous_closed()
    
                    # Attempt to sync our changes back to other nodes.
                    self._state_holder.mark_dirty()
    

    The following diagram shows the details.

    +-----------------------------------------+                          +---------------------------------------------------------------+
    |DynamicRendezvousHandler                 |                          | _DistributedRendezvousOpExecutor                              |
    |                                         |                          |                                                               |
    |                                         |                          |                     +------------------------+                |
    |   _settings: RendezvousSettings         |                          |        _state +---> | _RendezvousState       |                |
    |                                         |                          |                     |                        |                |
    |                                         |                          |                     |       participants     |                |
    |   _store: Store                         |                          |                     |       wait_list        |                |
    |                                         |                          |                     |       last_heartbeats  |                |
    |                                         |                          |                     |       deadline         |                |
    |   _state_holder: _RendezvousStateHolder |                          |                     +------------------------+                |
    |                                         | run(_RendezvousJoinOp()) |                     +-------------------------+               |
    |                                         |                          |      _settings +--> | RendezvousSettings      |               |
    |   _op_executor  +------------------------------------------------> |                     |                         |               |
    |                                         |                          |                     +-------------------------+               |
    |                                         |                          |                     +--------------------------------------+  |
    +-----------------------------------------+                          | _state_holder +---> | _BackendRendezvousStateHolder        |  |
                                                                         |                     |                                      |  |
                                                                         |                     |        _backend: RendezvousBackend   |  |
                                                                         |                     |        _state: _RendezvousState      |  |
                                                                         |                     |        _settings: RendezvousSettings |  |
                                                                         |                     |                                      |  |
                                                                         |                     +--------------------------------------+  |
                                                                         |                     +--------------------------------------+  |
                                                                         |                     | _NodeDesc                            |  |
                                                                         |     _node +-------> |              fqdn: str               |  |
                                                                         |                     |              pid: int                |  |
                                                                         |                     |              local_id: int           |  |
                                                                         |                     |                                      |  |
                                                                         |                     +--------------------------------------+  |
                                                                         +---------------------------------------------------------------+
    

    2.2.3.2 Synchronization

    One thing to note in run: before any operator is executed, it calls self._state_holder.sync() to synchronize state among the workers and reach consensus.

    def sync(self) -> Optional[bool]:
        """See base class."""
        state_bits: Optional[bytes] = None
        token = None
        has_set: Optional[bool]
    
        if self._dirty: # the local node's state has changed
            has_set = False
            state_bits = pickle.dumps(self._state)
            # write our own state to the backend
            set_response = self._backend.set_state(state_bits, self._token)
            if set_response is not None:
                state_bits, token, has_set = set_response
        else: # nothing changed locally, so only fetch from the backend
            has_set = None
            if self._cache_duration > 0:
                # Avoid overloading the backend if we are asked to retrieve the
                # state repeatedly. Try to serve the cached state.
                if self._last_sync_time >= max(time.monotonic() - self._cache_duration, 0):
                    return None
            get_response = self._backend.get_state() # fetch the latest shared state from the backend
            if get_response is not None:
                state_bits, token = get_response
    
        if state_bits is not None:
            try:
                self._state = pickle.loads(state_bits) # replace the local state with the backend state
            except pickle.PickleError as exc:
                raise RendezvousStateError(
                    "The rendezvous state is corrupt. See inner exception for details."
                ) from exc
        else:
            self._state = _RendezvousState()
    
        if has_set and self._dead_nodes and log.isEnabledFor(logging.DEBUG):
            node_list = ", ".join(f"'{dead_node}'" for dead_node in self._dead_nodes)
            msg = (
                f"As part of the sync operation the node(s) {node_list} have been removed from the "
                f"rendezvous '{self._settings.run_id}' since they had no heartbeat."
            )
            self._record(message=msg)
    
        self._token = token
        self._dirty = False
        self._last_sync_time = time.monotonic()
        self._sanitize()
    
        return has_set
    
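    Backend

    The corresponding backend code lives in torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py.

    The backend uses the store as centralized storage, acting as the master. Each node is a client that goes to the master to update its own state and to fetch the state of the other nodes. This way all nodes share what they know and reach consensus. Clients that stop updating their metadata are also removed periodically.

    get_state simply reads the value from the store.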

    def get_state(self) -> Optional[Tuple[bytes, Token]]:
        """See base class."""
        base64_state: bytes = self._call_store("get", self._key)
    
        return self._decode_state(base64_state)
    

    set_state performs a compare-and-set. It returns the new state together with a flag that indicates whether the state was actually updated.

    def set_state(
        self, state: bytes, token: Optional[Token] = None
    ) -> Optional[Tuple[bytes, Token, bool]]:
        """See base class."""
        base64_state_str: str = b64encode(state).decode()
    
        if token:
            # Shortcut if we know for sure that the token is not valid.
            if not isinstance(token, bytes):
                result = self.get_state()
                if result is not None:
                    tmp = *result, False
                    # Python 3.6 does not support tuple unpacking in return
                    # statements.
                    return tmp
                return None
    
            token = token.decode()
        else:
            token = self._NULL_SENTINEL
    
        base64_state: bytes = self._call_store("compare_set", self._key, token, base64_state_str)
    
        state_token_pair = self._decode_state(base64_state)
        if state_token_pair is None:
            return None
    
        new_state, new_token = state_token_pair
    
        # C10d Store's compare_set method does not offer an easy way to find out
        # whether our write attempt was successful. As a brute-force solution we
        # perform a bitwise comparison of our local state and the remote state.
        return new_state, new_token, new_state == state
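
The compare-and-set idea behind set_state can be shown with a small in-memory sketch. ToyStore and its integer version token are assumptions made for this illustration; the real C10d backend talks to a torch.distributed Store and derives its token differently. The point is only the optimistic pattern: a write succeeds when the caller's token still matches, otherwise the caller receives the winner's state and a fresh token.

```python
from typing import Dict, Tuple


class ToyStore:
    """Hypothetical in-memory store; token 0 means "no state written yet"."""

    def __init__(self) -> None:
        self._data: Dict[str, Tuple[bytes, int]] = {}

    def compare_set(
        self, key: str, expected_token: int, value: bytes
    ) -> Tuple[bytes, int, bool]:
        current_value, current_token = self._data.get(key, (b"", 0))
        if current_token != expected_token:
            # Someone else wrote first: report their state and the fresh token.
            return current_value, current_token, False
        new_token = current_token + 1
        self._data[key] = (value, new_token)
        return value, new_token, True


store = ToyStore()

# Node A and node B both start from the same empty view (token 0).
state_a, token_a, ok_a = store.compare_set("rdzv", 0, b"participants=[A]")
state_b, token_b, ok_b = store.compare_set("rdzv", 0, b"participants=[B]")

# B lost the race. It adopts A's state, merges its own change, and retries
# with the fresh token it was handed back.
state_b2, token_b2, ok_b2 = store.compare_set("rdzv", token_b, b"participants=[A,B]")
```

In this toy the boolean is returned directly. The listing above has to infer it by comparing the returned state with the state it tried to write, because, as its comment notes, the C10d Store's compare_set offers no easy way to tell whether the write succeeded.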
    
    _sanitize

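    The _sanitize method post-processes the state received from the other nodes, for example by cleaning up failed nodes. If a node's last heartbeat is older than a threshold, the node is marked as a dead node and removed from the participants or from the wait list.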

    def _sanitize(self) -> None:
        state = self._state
    
        expire_time = datetime.utcnow() - (
            self._settings.keep_alive_interval * self._settings.keep_alive_max_attempt
        )
    
        # Filter out the dead nodes.
        self._dead_nodes = [
            node
            for node, last_heartbeat in state.last_heartbeats.items()
            if last_heartbeat < expire_time
        ]
    
        participant_removed = False
    
        for dead_node in self._dead_nodes:
            del state.last_heartbeats[dead_node] # forget the dead node's heartbeat
    
            try:
                del state.participants[dead_node] # remove the dead node from the participants
    
                participant_removed = True
            except KeyError:
                pass
    
            try:
                state.wait_list.remove(dead_node) # remove the dead node from the wait list
            except KeyError:
                pass
    
        if participant_removed:
            # Common epilogue shared with the _remove_from_participants()
            # function of _DistributedRendezvousOpExecutor.
            _remove_participant_epilogue(state, self._settings)
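
A small worked example of the expiry rule used by _sanitize. The interval and attempt count below are assumed values chosen for illustration, and find_dead_nodes is a hypothetical helper that takes the current time as a parameter so the result is reproducible.

```python
from datetime import datetime, timedelta
from typing import Dict, List

keep_alive_interval = timedelta(seconds=5)  # assumed value for illustration
keep_alive_max_attempt = 3                  # assumed value for illustration


def find_dead_nodes(last_heartbeats: Dict[str, datetime], now: datetime) -> List[str]:
    """Return the nodes whose last heartbeat is older than interval * max_attempt."""
    expire_time = now - keep_alive_interval * keep_alive_max_attempt
    return [node for node, beat in last_heartbeats.items() if beat < expire_time]


now = datetime(2024, 1, 1, 12, 0, 0)
last_heartbeats = {
    "node-a": now - timedelta(seconds=4),   # recent heartbeat
    "node-b": now - timedelta(seconds=15),  # exactly on the threshold
    "node-c": now - timedelta(seconds=16),  # past the threshold
}

dead = find_dead_nodes(last_heartbeats, now)
```

With these values the threshold is 15 seconds. Because the comparison is strict (`<`), node-b, whose heartbeat is exactly 15 seconds old, is still considered alive, and only node-c is reported as dead.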
    

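    Now that we have seen how the engine runs, we turn to the individual operators.

    0x03 Operators

    The business logic of the _RendezvousOpExecutor engine is split into two layers, user operations and internal business logic, which are decoupled from each other.

    • User operations are split into operators: heartbeat, join, close, and exit. The join operator, for example, is _RendezvousJoinOp.

    • Internal business logic is split into business functions. For example, _add_to_participants removes the node from the wait list and adds it to participants.

    • Operators and internal business functions do not map one to one, so a state-machine-like mechanism is needed to control them.

      • For example, the heartbeat operator may result in a timeout, a keep-alive, or a normal finish, and a different internal business function should be called depending on that result. This mapping is expressed through Actions.
      • Taken together, the operators make up a state machine.
      • An operator's job is to produce Actions, which determine the next step of the state machine.
    • The engine executes the concrete business logic according to the Action; in other words, Actions are what decouple the two layers.

    Concretely, the engine can be divided logically into three layers: operators at the top, Actions in the middle, and business functions at the bottom.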

    +-----------------------------------------------------------------------------------------+
    |                                                                                         |
    | _RendezvousKeepAliveOp    _RendezvousCloseOp    _RendezvousExitOp    _RendezvousJoinOp  |
    |                                                                                         |
    +-------------+---------------------+--------------------+------------------+-------------+
                  |                     |                    |                  |
                  |                     |                    |                  |
                  |                     |                    |                  |
                  |                     |                    |                  |
                  v                     v                    v                  v
    
    +-----------------------------------------------------------------------------------------+
    |                                                                                         |
    | KEEP_ALIVE   ADD_TO_PARTICIPANTS   ADD_TO_WAIT_LIST   REMOVE_FROM_WAIT_LIST   ......    |
    |                                                                                         |
    +-------------+----------+----------+----------+---------+---------+---------+------------+
                  |          |          |          |         |         |         |
                  |          |          |          |         |         |         |
                  |          |          |          |         |         |         |
                  |          |          |          |         |         |         |
                  v          v          v          v         v         v         v
    
    +-----------------------------------------------------------------------------------------+
    |                                                                                         |
    | _add_to_participants    _remove_from_participants     _add_to_wait_list        ......   |
    |                                                                                         |
    |                                                                                         |
    +-----------------------------------------------------------------------------------------+
    

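    We go through them one by one.

    3.1 Actions

    We start with the middle layer and look at which Actions exist. The engine's actions, chosen according to the state of the rendezvous, are listed below. The code is in torch/distributed/elastic/rendezvous/dynamic_rendezvous.py.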

    class _Action(Enum):
        """Specifies the possible actions based on the state of the rendezvous."""
    
        KEEP_ALIVE = 1
        ADD_TO_PARTICIPANTS = 2
        ADD_TO_WAIT_LIST = 3
        REMOVE_FROM_PARTICIPANTS = 4
        REMOVE_FROM_WAIT_LIST = 5
        MARK_RENDEZVOUS_COMPLETE = 6
        MARK_RENDEZVOUS_CLOSED = 7
        SYNC = 8
        ERROR_CLOSED = 9
        ERROR_TIMEOUT = 10
        FINISH = 11
    

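    3.2 Operators

    The engine implements several operators; roughly speaking, each user operation corresponds to one operator. Below are a few examples. An operator picks the Action type according to the state of the rendezvous.

    3.2.1 Heartbeat

    3.2.1.1 Checking the heartbeat

    _RendezvousKeepAliveOp decides the next Action from the current state and time. It is run periodically: when a heartbeat is due for this node it returns KEEP_ALIVE (or ERROR_TIMEOUT once the deadline has passed), and otherwise it returns FINISH. This keeps the node from being treated as failed.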

    class _RendezvousKeepAliveOp:
        """Represents a rendezvous keep-alive update operation."""
    
        def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
            if _should_keep_alive(ctx):
                if time.monotonic() > deadline:
                    return _Action.ERROR_TIMEOUT
                return _Action.KEEP_ALIVE
            return _Action.FINISH
    

    The _should_keep_alive function is:

    def _should_keep_alive(ctx: _RendezvousContext) -> bool:
        """Determines whether a keep-alive heartbeat should be sent."""
        try:
            last_heartbeat = ctx.state.last_heartbeats[ctx.node]
        except KeyError:
            return False
    
        return last_heartbeat <= datetime.utcnow() - ctx.settings.keep_alive_interval
    
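    3.2.1.2 Periodic invocation

    Note that sync is called before any operator runs, and sync synchronizes state between nodes. Because heartbeats are periodic, state synchronization is periodic too.

    DynamicRendezvousHandler starts a timer that periodically calls the _keep_alive_weak method.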

    def _start_heartbeats(self) -> None:
        self._keep_alive_timer = _PeriodicTimer(
            self._settings.keep_alive_interval, self._keep_alive_weak, weakref.ref(self)
        )
    
        self._keep_alive_timer.set_name(f"RendezvousKeepAliveTimer_{self._this_node.local_id}")
        self._keep_alive_timer.start()
    

    Next, _keep_alive_weak calls self._keep_alive().

    @staticmethod
    def _keep_alive_weak(weak_self) -> None:
        self = weak_self()
        if self is not None:
            self._keep_alive()
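
The timer receives weakref.ref(self) rather than self. A plausible reason, stated here as an assumption, is that the timer should not keep the handler alive on its own. The sketch below shows the pattern without any thread: ToyHandler is a hypothetical class, and the callback is invoked by hand where a timer would normally call it.

```python
import gc
import weakref


class ToyHandler:
    """Hypothetical stand-in for a handler that owns a periodic heartbeat."""

    def __init__(self) -> None:
        self.beats = 0

    def _keep_alive(self) -> None:
        self.beats += 1

    @staticmethod
    def _keep_alive_weak(weak_self) -> bool:
        # Dereference the weak reference; it yields None once the handler is gone.
        self = weak_self()
        if self is None:
            return False
        self._keep_alive()
        return True


handler = ToyHandler()
weak = weakref.ref(handler)

fired_while_alive = ToyHandler._keep_alive_weak(weak)   # handler exists: heartbeat sent
beats = handler.beats

del handler    # drop the only strong reference
gc.collect()   # make collection explicit for interpreters without refcounting
fired_after_delete = ToyHandler._keep_alive_weak(weak)  # no-op: the handler is gone
```

Because the callback holds only a weak reference, a pending timer tick after the handler is discarded turns into a harmless no-op instead of reviving or pinning the object.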
    

    _keep_alive runs _RendezvousKeepAliveOp.

    def _keep_alive(self) -> None:
        self._heartbeat_lock.acquire()
        op = _RendezvousKeepAliveOp()
        deadline = self._get_deadline(self._settings.timeout.heartbeat)
    
        try:
            self._op_executor.run(op, deadline)
            msg = (
                f"The node '{self._this_node}' has sent a keep-alive heartbeat to the rendezvous "
                f"'{self._settings.run_id}'."
            )
            self._record(message=msg)
            log.debug(msg)
        except RendezvousError as ex:
            msg = (
                f"The node '{self._this_node}' has failed to send a keep-alive heartbeat to the "
                f"rendezvous '{self._settings.run_id}' due to an error of type {type(ex).__name__}."
            )
            self._record(message=msg, node_state=NodeState.FAILED)
        finally:
            self._heartbeat_lock.release()
    
    3.2.1.3 Setting the heartbeat

    In addition, _DistributedRendezvousOpExecutor has a function with the same name, _keep_alive, which implements the internal logic; it is covered later.

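    3.2.2 Close

    _RendezvousCloseOp decides the next Action from the current state and time: FINISH if the rendezvous is already closed, ERROR_TIMEOUT if the deadline has passed, and MARK_RENDEZVOUS_CLOSED otherwise.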

    class _RendezvousCloseOp:
        """Represents a rendezvous close operation."""
    
        def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
            if ctx.state.closed:
                return _Action.FINISH
            if time.monotonic() > deadline:
                return _Action.ERROR_TIMEOUT
            return _Action.MARK_RENDEZVOUS_CLOSED
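
The operators compare time.monotonic() with deadline, which suggests that the deadline is an absolute timestamp on the monotonic clock. The helper below is a hypothetical sketch of how such a deadline could be derived from a timeout; it is not the body of the handler's _get_deadline, which is not shown in this article.

```python
import time
from datetime import timedelta


def get_deadline(timeout: timedelta) -> float:
    """Hypothetical helper: absolute deadline on the monotonic clock."""
    return time.monotonic() + timeout.total_seconds()


def is_timed_out(deadline: float) -> bool:
    """Same check the operators perform before returning ERROR_TIMEOUT."""
    return time.monotonic() > deadline


future_deadline = get_deadline(timedelta(seconds=30))
past_deadline = get_deadline(timedelta(seconds=-1))  # already expired
```

A monotonic clock is a natural fit for this kind of check because it never jumps backwards when the wall clock is adjusted.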
    

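    3.2.3 Exit

    _RendezvousExitOp decides the next Action from the current state and time. If this node is not among the participants, there is nothing to do. Otherwise it returns an Action that removes the node from the participants list, or the timeout Action if the deadline has passed.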

    class _RendezvousExitOp:
        """Represents a rendezvous exit operation."""
    
        def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
            if ctx.node in ctx.state.participants:
                if time.monotonic() > deadline:
                    return _Action.ERROR_TIMEOUT
                return _Action.REMOVE_FROM_PARTICIPANTS
            return _Action.FINISH
    

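    3.2.4 Join

    _RendezvousJoinOp behaves differently depending on the state of the system: it may try to add this node to the participants or to the wait list, or it may keep waiting. See the code comments for details.

    • Take the _RendezvousState from the context and store it in state.
    • If the state is closed, the rendezvous no longer accepts new nodes, so return _Action.ERROR_CLOSED.
    • Check whether this node is a participant and store the result in is_participant.
    • If the rendezvous is already complete and this node is a participant, there is nothing more to do, so return _Action.FINISH.
    • Get the current time now.
    • If now > deadline, the operation has timed out.
      • If there is still time to roll back, this node tries to revert to its previous state.
        • If this node is a participant, the total number of nodes did not reach min in time. Although it is a participant, it has to be removed from the participant list, so return _Action.REMOVE_FROM_PARTICIPANTS.
        • If this node is in the wait list, it could not wait until the next round of the rendezvous, so it has to be removed from the wait list: return _Action.REMOVE_FROM_WAIT_LIST.
      • Otherwise return _Action.ERROR_TIMEOUT.
    • Otherwise the operation has not timed out and processing continues.
      • If state.complete holds and this node is not a participant (the participant case was handled above), the rendezvous has already completed. If the maximum number of nodes has not been reached and this node is not yet in the wait list, it is added to the wait list; when the next monitoring period arrives, a new rendezvous is run and nodes in the wait list can be moved to the participant list. So return _Action.ADD_TO_WAIT_LIST.
      • If this node is a participant and the state is not complete (the complete case was handled above), and the minimum number of nodes has been reached and the rendezvous deadline has passed, the rendezvous can be concluded, so return _Action.MARK_RENDEZVOUS_COMPLETE.
      • Otherwise the rendezvous is not complete and this node is not a participant, so join the participant list directly: return _Action.ADD_TO_PARTICIPANTS.
    • If a heartbeat is due, return _Action.KEEP_ALIVE.
    • Otherwise return _Action.SYNC.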
    class _RendezvousJoinOp:
        """Represents a rendezvous join operation."""
    
        def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
            state = ctx.state  # read the _RendezvousState from the context
    
            # A closed rendezvous means that it no longer accepts new nodes.
            if state.closed:
                return _Action.ERROR_CLOSED  # the rendezvous is closed
    
            is_participant = ctx.node in state.participants  # is this node a participant?
    
            # If we are part of the rendezvous and it is already complete there is
            # no further action to take.
            if state.complete and is_participant:  # complete and we are a participant
                return _Action.FINISH
    
            now = time.monotonic()
            if now > deadline:  # the operation has timed out
                rollback_period = 5  # 5 seconds
    
                # If we still have time to rollback (a short period on top of the
                # operation deadline), try to remove ourself from the rendezvous.
                # It is okay if we can't though as our keep-alive will eventually
                # expire.
                if now <= deadline + rollback_period:  # still time to roll back
                    # If we are part of the rendezvous, it means we couldn't find
                    # enough participants to complete it on time.
                    if is_participant:  # min was not reached in time, so leave
                        return _Action.REMOVE_FROM_PARTICIPANTS  # remove from the participants list
                    # If we are in the wait list, it means we couldn't wait till the
                    # next round of the rendezvous.
                    if ctx.node in state.wait_list:  # could not wait for the next round
                        return _Action.REMOVE_FROM_WAIT_LIST  # remove from the wait list
                return _Action.ERROR_TIMEOUT  # report the timeout
    
            if state.complete:  # the rendezvous is already complete
                # If we are here, it means we are not part of the rendezvous. In
                # case the rendezvous has capacity for additional participants add
                # ourself to the wait list for the next round.
                if len(state.participants) < ctx.settings.max_nodes:  # below the maximum node count
                    if ctx.node not in state.wait_list:  # not yet in the wait list
                        return _Action.ADD_TO_WAIT_LIST  # join the wait list
            elif is_participant:  # already in the participants list
                # If the rendezvous has enough number of participants including us,
                # check whether we have passed the rendezvous deadline. If yes,
                # complete it.
                if len(state.participants) >= ctx.settings.min_nodes:  # minimum node count reached
                    if cast(datetime, state.deadline) < datetime.utcnow():  # last-call deadline has passed
                        return _Action.MARK_RENDEZVOUS_COMPLETE  # mark the rendezvous complete
            else:  # not complete and not a participant: join directly
                # The rendezvous is not complete yet and we are not part of it. Try
                # to join.
                return _Action.ADD_TO_PARTICIPANTS
    
            if _should_keep_alive(ctx):  # a heartbeat is due
                return _Action.KEEP_ALIVE
    
            # At this point either the rendezvous is not complete, but we are part
            # of it, which means we have to wait for other participants to join; or
            # the rendezvous is complete, but we are not part of it, which means we
            # have to wait for the next round.
            return _Action.SYNC  # otherwise just sync the state
    

    The logic is as follows:

                               state.closed
                            +-------------------------->   _Action.ERROR_CLOSED
                            |
                            |
                            |  complete & participant
                            +-------------------------->   _Action.FINISH
                            |
                            |
                            |  timeout & participant
                            +-------------------------->   _Action.REMOVE_FROM_PARTICIPANTS
                            |
                            |
                            |  timeout & wait
                            +-------------------------->   _Action.REMOVE_FROM_WAIT_LIST
                            |
    +-------------------+   |
    |                   |   |  timeout
    | _RendezvousJoinOp +------------------------------>   _Action.ERROR_TIMEOUT
    |                   |   |
    +-------------------+   |  complete & < max & not wait
                            |
                            +-------------------------->   _Action.ADD_TO_WAIT_LIST
                            |
                            |  not complete & participant & >= min & deadline
                            |
                            +-------------------------->   _Action.MARK_RENDEZVOUS_COMPLETE
                            |
                            |  not complete & not participant
                            |
                            +-------------------------->   _Action.ADD_TO_PARTICIPANTS
                            |
                            |  _should_keep_alive
                            |
                            +-------------------------->   _Action.KEEP_ALIVE
                            |
                            |  else
                            |
                            +-------------------------->   _Action.SYNC
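
    The branches in the diagram can be exercised with a simplified model. The sketch below is not the PyTorch implementation: Action, State and join_action are hypothetical names, nodes are plain strings, the heartbeat check is reduced to a boolean argument, and a single float clock is used for both the operation deadline and the last-call deadline (the real code uses time.monotonic() for the former and datetime.utcnow() for the latter).

```python
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Optional

ROLLBACK_PERIOD = 5.0  # seconds, same value as rollback_period in the listing


class Action(Enum):
    ERROR_CLOSED = auto()
    FINISH = auto()
    REMOVE_FROM_PARTICIPANTS = auto()
    REMOVE_FROM_WAIT_LIST = auto()
    ERROR_TIMEOUT = auto()
    ADD_TO_WAIT_LIST = auto()
    MARK_RENDEZVOUS_COMPLETE = auto()
    ADD_TO_PARTICIPANTS = auto()
    KEEP_ALIVE = auto()
    SYNC = auto()


@dataclass
class State:
    closed: bool = False
    complete: bool = False
    participants: set = field(default_factory=set)
    wait_list: set = field(default_factory=set)
    last_call_deadline: Optional[float] = None  # stands in for state.deadline


def join_action(node, state, now, op_deadline, min_nodes, max_nodes,
                heartbeat_due=False):
    """Simplified restatement of the join decision."""
    if state.closed:
        return Action.ERROR_CLOSED
    is_participant = node in state.participants
    if state.complete and is_participant:
        return Action.FINISH
    if now > op_deadline:
        # Within the rollback window, try to undo our own registration.
        if now <= op_deadline + ROLLBACK_PERIOD:
            if is_participant:
                return Action.REMOVE_FROM_PARTICIPANTS
            if node in state.wait_list:
                return Action.REMOVE_FROM_WAIT_LIST
        return Action.ERROR_TIMEOUT
    if state.complete:
        # Complete without us: queue for the next round if there is capacity.
        if len(state.participants) < max_nodes and node not in state.wait_list:
            return Action.ADD_TO_WAIT_LIST
    elif is_participant:
        # Enough nodes and the last-call deadline has passed: complete it.
        if (len(state.participants) >= min_nodes
                and state.last_call_deadline is not None
                and state.last_call_deadline < now):
            return Action.MARK_RENDEZVOUS_COMPLETE
    else:
        return Action.ADD_TO_PARTICIPANTS
    return Action.KEEP_ALIVE if heartbeat_due else Action.SYNC
```

    For example, with min_nodes=2 and max_nodes=3, a node joining an empty rendezvous gets ADD_TO_PARTICIPANTS, and a late node arriving after completion gets ADD_TO_WAIT_LIST while capacity remains.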
    
    

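    The PyTorch source includes a state diagram for the etcd-backed rendezvous, which serves as a rough reference for comparison with the c10d states.

    The etcd backend's join can be divided into four phases:

    • setup: a value is written to a fixed directory. This acts as an exclusive lock; if the write fails, a rendezvous is already in progress.
    • join (joinable): entered if the write succeeds. When the wait time expires or the number of nodes participating in training reaches the maximum, the rendezvous moves to the frozen phase.
    • frozen (confirm): every node must confirm before the rendezvous moves to the final phase.
    • final: ranks are assigned, and the instance with RANK 0 becomes the master.

    Following that diagram, we can expand the c10d logic as follows.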

          +
          |
          |
          v
    +-----+------+
    |            |
    |   closed   +---------------> ERROR_CLOSED
    |            |
    +-----+------+
          |
          |
          v
    +-----+------+  is_participant
    |            |
    |  complete  +---------------> FINISH
    |            |
    +-----+------+
          |                                                                                 is_participant
          |
          v                                                                                +----> REMOVE_FROM_PARTICIPANTS
    +-----+-------+  now > deadline  +-----------+    now < rollback     +-----------+     |
    |             |                  |           |                       |           |     |
    |    join     +----------------> |  timeout  +---------------------->+ rollback  +-----+
    |             |                  |           |                       |           |     |
    +-----+-------+                  +----+------+                       +-----------+     |
          |                               |                                                | in state.wait_list
          |                               |    now > rollback                              |
          |  now < deadline               |                                                +----> REMOVE_FROM_WAIT_LIST
          |                               +---------->  ERROR_TIMEOUT
          |
          |   complete && not is_participant && < max && not in state.wait_list
          |
          +------------------------------------------------------------------>  ADD_TO_WAIT_LIST
          |
          |   not complete && is_participant && >= min && > deadline
          |
          +------------------------------------------------------------------>  MARK_RENDEZVOUS_COMPLETE
          |
          |   not complete && not is_participant
          |
          +----------------------------------------->  ADD_TO_PARTICIPANTS
          |
          |   _should_keep_alive
          |
          +--------------------------->  KEEP_ALIVE
          |
          |
          v
         SYNC
    
    


    0x04 Business operations

    Internally, _DistributedRendezvousOpExecutor.run selects a different function to execute depending on the action.

                if action == _Action.KEEP_ALIVE:
                    self._keep_alive()
                elif action == _Action.ADD_TO_PARTICIPANTS:
                    self._add_to_participants()
                elif action == _Action.ADD_TO_WAIT_LIST:
                    self._add_to_wait_list()
                elif action == _Action.REMOVE_FROM_PARTICIPANTS:
                    self._remove_from_participants()
                elif action == _Action.REMOVE_FROM_WAIT_LIST:
                    self._remove_from_wait_list()
                elif action == _Action.MARK_RENDEZVOUS_COMPLETE:
                    self._mark_rendezvous_complete()
                elif action == _Action.MARK_RENDEZVOUS_CLOSED:
                    self._mark_rendezvous_closed()
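
    The "decide an action, then apply it" pattern can be illustrated with a toy loop. This is a sketch under stated assumptions rather than the real executor: ToyExecutor, _next_action and the dictionary-based dispatch are hypothetical, there is no shared store or sync step, and only two state-changing actions are modeled.

```python
from enum import Enum, auto


class Action(Enum):
    ADD_TO_PARTICIPANTS = auto()
    MARK_RENDEZVOUS_COMPLETE = auto()
    FINISH = auto()


class ToyExecutor:
    """Single-node toy: decide an action, apply it, repeat until FINISH."""

    def __init__(self, node):
        self.node = node
        self.participants = {}
        self.complete = False
        # Map each state-changing action to the method that applies it.
        self._handlers = {
            Action.ADD_TO_PARTICIPANTS: self._add_to_participants,
            Action.MARK_RENDEZVOUS_COMPLETE: self._mark_rendezvous_complete,
        }

    def _add_to_participants(self):
        self.participants[self.node] = 0

    def _mark_rendezvous_complete(self):
        self.complete = True

    def _next_action(self):
        # Stands in for the op (for example the join op) that inspects the state.
        if self.complete:
            return Action.FINISH
        if self.node not in self.participants:
            return Action.ADD_TO_PARTICIPANTS
        return Action.MARK_RENDEZVOUS_COMPLETE

    def run(self):
        trace = []
        while True:
            action = self._next_action()
            trace.append(action)
            if action is Action.FINISH:
                return trace
            self._handlers[action]()
```

    Running ToyExecutor("node-0").run() walks through ADD_TO_PARTICIPANTS, MARK_RENDEZVOUS_COMPLETE and FINISH in that order. A lookup table is just one way to express the dispatch; the if/elif chain in the listing does the same job.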
    

    Let's look at the logic of each of these internal functions.

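    4.1 Adding a participant

    On receiving ADD_TO_PARTICIPANTS, the executor calls _add_to_participants, which removes the node from the wait list and adds it to participants. When the participant count reaches min_nodes, it sets the last-call deadline; when the count reaches max_nodes, it marks the rendezvous complete immediately.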

        def _add_to_participants(self) -> None:
    
            state = self._state
    
            try:
                state.wait_list.remove(self._node)
            except KeyError:
                pass
    
            # The ranks of the participants will be set once the rendezvous is
            # complete.
            state.participants[self._node] = 0
    
            self._keep_alive()
    
            if len(state.participants) == self._settings.min_nodes:
                state.deadline = datetime.utcnow() + self._settings.timeout.last_call
    
            if len(state.participants) == self._settings.max_nodes:
                self._mark_rendezvous_complete()
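
    The two thresholds in this function are easier to see in a small simulation. The sketch below is a simplified model, not the library code: ToyState and add_to_participants are hypothetical, nodes are strings, the heartbeat update is omitted, and the current time is passed in so the result is deterministic.

```python
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Optional


@dataclass
class ToyState:
    participants: dict = field(default_factory=dict)
    wait_list: set = field(default_factory=set)
    deadline: Optional[datetime] = None
    complete: bool = False


def add_to_participants(state, node, min_nodes, max_nodes, last_call, now):
    """Add a node and apply the min_nodes / max_nodes thresholds."""
    state.wait_list.discard(node)  # the node may or may not have been waiting
    state.participants[node] = 0   # real ranks are assigned on completion

    # Reaching min_nodes starts the last-call window for late joiners.
    if len(state.participants) == min_nodes:
        state.deadline = now + last_call

    # Reaching max_nodes completes the rendezvous right away.
    if len(state.participants) == max_nodes:
        state.complete = True
        state.deadline = None
        for rank, n in enumerate(sorted(state.participants)):
            state.participants[n] = rank
```

    With min_nodes=2 and max_nodes=3, the second node to join starts the last-call window and the third node completes the rendezvous without waiting for that window to expire.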
    

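    4.2 Removing a participant

    On receiving REMOVE_FROM_PARTICIPANTS, the executor calls _remove_from_participants, which deletes the node from participants and last_heartbeats. If the rendezvous was complete and no participants remain, it moves on to the next round; if the rendezvous was not complete and the count drops below min_nodes, it clears the deadline.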

        def _remove_from_participants(self) -> None:
    
            state = self._state
            del state.participants[self._node]
            del state.last_heartbeats[self._node]
    
            if state.complete:
                # If we do not have any participants left, move to the next round.
                if not state.participants:
                    state.complete = False
                    state.round += 1
            else:
                if len(state.participants) < self._settings.min_nodes:
                    state.deadline = None
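
    The two outcomes of a removal can be checked with a small model. This is a sketch, not the library code: ToyState and remove_from_participants are hypothetical, nodes are strings, and the deadline is a plain float.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ToyState:
    participants: dict = field(default_factory=dict)
    last_heartbeats: dict = field(default_factory=dict)
    complete: bool = False
    round: int = 0
    deadline: Optional[float] = None


def remove_from_participants(state, node, min_nodes):
    """Remove a node and adjust the round or the last-call deadline."""
    del state.participants[node]
    del state.last_heartbeats[node]

    if state.complete:
        # The last participant of a completed round left: start a new round.
        if not state.participants:
            state.complete = False
            state.round += 1
    elif len(state.participants) < min_nodes:
        # Below the minimum again, so the last-call window no longer applies.
        state.deadline = None
```

    In a rendezvous that is still forming, losing a node below min_nodes cancels the last-call deadline. In a completed rendezvous, the round counter advances only once every participant has left.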
    

    4.3 Adding to the wait list

    On receiving ADD_TO_WAIT_LIST, the executor calls _add_to_wait_list, which adds the node to wait_list and refreshes its heartbeat.

        def _add_to_wait_list(self) -> None:
            self._state.wait_list.add(self._node)
            self._keep_alive()
    

    4.4 Removing from the wait list

    On receiving REMOVE_FROM_WAIT_LIST, the executor calls _remove_from_wait_list, which removes the node from wait_list and deletes its heartbeat entry.

        def _remove_from_wait_list(self) -> None:
            self._state.wait_list.remove(self._node)
            del self._state.last_heartbeats[self._node]
    

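    4.5 Marking completion

    On receiving MARK_RENDEZVOUS_COMPLETE, the executor calls _mark_rendezvous_complete, which marks the rendezvous as complete and assigns a rank to every participant. _mark_rendezvous_closed, shown alongside it, simply sets the closed flag.

    Every node sorts the participants with the same algorithm, so the ranks are identical on every node.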

        def _mark_rendezvous_complete(self) -> None:
            state = self._state
    
            state.complete = True
            state.deadline = None
    
            # Assign the ranks.
            for rank, node in enumerate(sorted(state.participants)):
                state.participants[node] = rank
    
        def _mark_rendezvous_closed(self) -> None:
            self._state.closed = True
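
    The claim that every node derives the same ranks follows from sorting: the result does not depend on the order in which participants were inserted. The sketch below illustrates this with plain tuples of (hostname, pid, local id) as hypothetical stand-ins for the node descriptors; assign_ranks is an illustrative helper, not a PyTorch function.

```python
def assign_ranks(participants):
    """Assign ranks by sorted node order, as in the loop above."""
    return {node: rank for rank, node in enumerate(sorted(participants))}


# Two nodes hold the same participants, but recorded in a different order.
view_on_node_1 = {("host-b", 200, 0): 0, ("host-a", 100, 0): 0, ("host-c", 300, 0): 0}
view_on_node_2 = {("host-c", 300, 0): 0, ("host-b", 200, 0): 0, ("host-a", 100, 0): 0}

ranks_1 = assign_ranks(view_on_node_1)
ranks_2 = assign_ranks(view_on_node_2)
```

    Both views yield the same mapping, with ("host-a", 100, 0) receiving rank 0.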
    

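    4.6 Heartbeat

    On receiving the KEEP_ALIVE action, the executor calls _keep_alive to maintain the heartbeat. _keep_alive is also called from methods such as _add_to_participants. It updates last_heartbeats in the local state; on the next sync, last_heartbeats is written to the key-value store so that other nodes can learn this node's status. Locally, _sanitize uses last_heartbeats to clean up the state, as mentioned earlier.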

    def _keep_alive(self) -> None:
        msg = (
            f"The node '{self._node}' updated its keep-alive heartbeat time for the rendezvous "
            f"'{self._settings.run_id}'. Pending sync."
        )
        self._record(message=msg)
        self._state.last_heartbeats[self._node] = datetime.utcnow()
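
    To show how recorded heartbeats can be turned into a liveness decision, here is a minimal sketch of an expiry check. It is an illustration under assumptions, not the _sanitize implementation: find_dead_nodes is a hypothetical helper, and the parameter names keep_alive_interval and keep_alive_max_attempt are modeled on the rendezvous settings.

```python
from datetime import datetime, timedelta


def find_dead_nodes(last_heartbeats, now, keep_alive_interval, keep_alive_max_attempt):
    """Return nodes whose last heartbeat is older than the allowed window."""
    # A node may miss up to keep_alive_max_attempt intervals before it is dead.
    expire_time = now - keep_alive_interval * keep_alive_max_attempt
    return [node for node, beat in last_heartbeats.items() if beat < expire_time]


now = datetime(2022, 1, 1, 0, 1, 0)
heartbeats = {
    "node-a": datetime(2022, 1, 1, 0, 0, 50),  # 10 seconds ago
    "node-b": datetime(2022, 1, 1, 0, 0, 30),  # 30 seconds ago
}
dead = find_dead_nodes(heartbeats, now, timedelta(seconds=5), 3)
```

    With a 5-second interval and 3 allowed attempts, the window is 15 seconds, so only node-b is reported as dead.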
    

    The _record method is as follows:

    def _record(self, message: str, node_state: NodeState = NodeState.RUNNING) -> None:
        construct_and_record_rdzv_event(
            name=f"{self.__class__.__name__}.{get_method_name()}",
            run_id=self._settings.run_id,
            message=message,
            node_state=node_state,
            hostname=self._node.fqdn,
            pid=self._node.pid,
            local_id=self._node.local_id,
        )
    

    It simply calls the following code to write a log entry.

    def record_rdzv_event(event: RdzvEvent) -> None:
        _get_or_create_logger("dynamic_rendezvous").info(event.serialize())
    
    def construct_and_record_rdzv_event(
        run_id: str,
        message: str,
        node_state: NodeState,
        name: str = "",
        hostname: str = "",
        pid: Optional[int] = None,
        master_endpoint: str = "",
        local_id: Optional[int] = None,
        rank: Optional[int] = None,
    ) -> None:
        # We don't want to perform an extra computation if not needed.
        if isinstance(get_logging_handler("dynamic_rendezvous"), logging.NullHandler):
            return
    
        # Set up parameters.
        if not hostname:
            hostname = socket.getfqdn()
        if not pid:
            pid = os.getpid()
    
        # Determines which file called this function.
        callstack = inspect.stack()
        filename = "no_file"
        if len(callstack) > 1:
            stack_depth_1 = callstack[1]
            filename = os.path.basename(stack_depth_1.filename)
            if not name:
                name = stack_depth_1.function
    
        # Delete the callstack variable. If kept, this can mess with python's
        # garbage collector as we are holding on to stack frame information in
        # the inspect module.
        del callstack
    
        # Set up error trace if this is an exception
        if node_state == NodeState.FAILED:
            error_trace = traceback.format_exc()
        else:
            error_trace = ""
    
        # Initialize event object
        event = RdzvEvent(
            name=f"{filename}:{name}",
            run_id=run_id,
            message=message,
            hostname=hostname,
            pid=pid,
            node_state=node_state,
            master_endpoint=master_endpoint,
            rank=rank,
            local_id=local_id,
            error_trace=error_trace,
        )
    
        # Finally, record the event.
        record_rdzv_event(event)
    

    This concludes the analysis of the engine. In the next article we will try to review everything again from an overall perspective.


  • Original article: https://www.cnblogs.com/rossiXYZ/p/15739391.html