• [源码解析] PyTorch 分布式之弹性训练(7)节点变化


    [源码解析] PyTorch 分布式之弹性训练(7)---节点变化

    0x00 摘要

    本文分析如何处理节点变化。即对成员更改作出反应,并使用新的成员来重启所有workers,从而实现弹性训练。

    总体思路是和当工作进程失败时的处理一样:相应elastic agent将杀死该节点上的所有工作进程,与其他代理建立会合(rendezvous),并使用新的会合(rendezvous)信息重新启动所有工作进程。

    弹性训练系列文章如下:

    [源码解析] PyTorch 分布式之弹性训练(1) --- 总体思路

    [源码解析] PyTorch 分布式之弹性训练(2)---启动&单节点流程

    [源码解析] PyTorch 分布式之弹性训练(3)---代理

    [源码解析] PyTorch 分布式之弹性训练(4)---Rendezvous 架构和逻辑

    [源码解析] PyTorch 分布式之弹性训练(5)---Rendezvous 引擎

    [源码解析] PyTorch 分布式之弹性训练(6)---监控/容错

    0x01 变化方式

    节点变化有两点方式。

    1.1 Scale-down

    节点离开(scale-down)的处理如下:

    • 当Scale down事件发生时,rendezvous将不会通知 torchelastic agent。
    • torchelastic agent 自己会监控到有进程错误,从而进行处理。
    • 如果TE agent以max_restarts=0配置启动,它依赖于底层调度程序来处理作业重新启动。
    • 如果max_restarts>0,TE代理将终止workers并开始新一轮rendezvous。
      • 代理得到离开的通知,于是现有workers(所有节点上的)都全部停止。
      • 这些workers将形成一个新的“WorkerGroup”,所有worker都将以新的RANKWORLD_SIZE 运行。

    1.2 Scale-up

    节点加入(scale-up)的处理如下:

    • 当Scale up事件发生时,新节点被提交到作业,torchelastic rendezvous将检测到有新节点试图加入。
      • 如果rendezvous已经达到最多节点数,新节点将不会添加到等待列表,因为已经满了,所以没有必要拆除已经完全体的rendezvous。新节点将一直等待直到超时(默认为600秒)。
      • 新节点将定期检查参与节点数目。如果数目变为小于max_nodes,等待节点将被加入到等待列表中。否则它将在600秒之后超时。
    • 当代理决定处理 Scale up时:
      • torchelastic rendezvous将停止所有workers并执行新一轮的 re-rendezvous。
      • 这些workers(现有以及新加入的)将形成一个新的“WorkerGroup”,所有worker都将以新的RANKWORLD_SIZE 运行。

    注:scale up发生时,max_restarts 将不会减少。

    0x02 节点加入

    2.1 新节点加入

    假设目前已经有了一个弹性训练集群正在运行,弹性区间为 (min=1, max=4)。目前已经有2个节点在运行,用户想启动第三个节点,于是使用如下方法启动一个新进程。

    python -m torch.distributed.run
            --nnodes=1:4
            --nproc_per_node=$NUM_TRAINERS
            --rdzv_id=$JOB_ID
            --rdzv_backend=c10d
            --rdzv_endpoint=$HOST_NODE_ADDR
            YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
    

    新进程会启动一个代理。代理经过一系列操作,调用 next_rendezvous,其中启动一个 ExitOp,一个 JoinOp 。

    def next_rendezvous(self) -> Tuple[Store, int, int]:
        exit_op = _RendezvousExitOp()
        join_op = _RendezvousJoinOp()
        
        self._op_executor.run(exit_op, deadline)
        self._op_executor.run(join_op, deadline)    
    

    2.2 处理 Join 操作

    以下操作是在 _DistributedRendezvousOpExecutor 之中。

    有了前文分析,我们知道,业务流程是 run 调用 Join 算子来分析出来下一个 Action,然后根据 Action 来执行对应的业务操作

    2.2.1 run处理

    _DistributedRendezvousOpExecutor.run 函数实现了基础逻辑,就是依据 action 类型进行各种操作。对于我们示例,state_handler 就是_RendezvousJoinOp。

        def run(
            self, state_handler: Callable[[_RendezvousContext, float], _Action], deadline: float
        ) -> None:
            """See base class."""
            action = None
    
            while action != _Action.FINISH: # 一直循环,直到结束
                
                # 这里很重要,在所有node之间做信息同步
                has_set = self._state_holder.sync() # 因为最新状态在 rendezvous。
                self._state = self._state_holder.state
                # 利用最新状态构建了 ctx
                ctx = _RendezvousContext(self._node, self._state, self._settings)
    
                # Determine the next action to take based on the current state of
                # the rendezvous.
                action = state_handler(ctx, deadline) # 调用_RendezvousJoinOp,决定下一个操作
    
                # 省略后续部分
    

    2.2.2 Join操作

    因为之前做了同步,所以这里的ctx就包括了最新的state,这就是Rendezvous的全局状态。因为此时,Rendezvous 已经结束了,所以 state 的状态是 complete,进入如下流程,返回 _Action.ADD_TO_WAIT_LIST。

        if state.complete:
            # If we are here, it means we are not part of the rendezvous. In
            # case the rendezvous has capacity for additional participants add
            # ourself to the wait list for the next round.
            if len(state.participants) < ctx.settings.max_nodes: # 如果当前节点数目小于最大配置
                if ctx.node not in state.wait_list: # 如果当前node不在等待列表之中
                    return _Action.ADD_TO_WAIT_LIST  # 发送一个等待action
    

    总体代码如下:

    class _RendezvousJoinOp:
        """Represents a rendezvous join operation."""
    
        def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
            state = ctx.state # 从上下文之中提取 _RendezvousState 状态
    
            # A closed rendezvous means that it no longer accepts new nodes.
            if state.closed:
                return _Action.ERROR_CLOSED # 如果已经结束,就返回 _Action.ERROR_CLOSED
    
            is_participant = ctx.node in state.participants # 看看是参与者
    
            # If we are part of the rendezvous and it is already complete there is
            # no further action to take.
            if state.complete and is_participant: # 如果是参与者且状态结束,就返回 _Action.FINISH
                return _Action.FINISH
    
            now = time.monotonic()
            if now > deadline: # 如果已经超时
                rollback_period = 5  # 5 seconds
    
                # If we still have time to rollback (a short period on top of the
                # operation deadline), try to remove ourself from the rendezvous.
                # It is okay if we can't though as our keep-alive will eventually
                # expire.
                if now <= deadline + rollback_period: # 如果还有时间来 rollback
                    # If we are part of the rendezvous, it means we couldn't find
                    # enough participants to complete it on time.
                    if is_participant: # 已经是参与者了
                        return _Action.REMOVE_FROM_PARTICIPANTS # 需要从参与者列表移除
                    # If we are in the wait list, it means we couldn't wait till the
                    # next round of the rendezvous.
                    if ctx.node in state.wait_list: # 已经在等待列表之中
                        return _Action.REMOVE_FROM_WAIT_LIST # 需要从等待列表移除
                return _Action.ERROR_TIMEOUT # 返回超时
    
            if state.complete: # 如果 rendezvous 已经结束
                # If we are here, it means we are not part of the rendezvous. In
                # case the rendezvous has capacity for additional participants add
                # ourself to the wait list for the next round.
                if len(state.participants) < ctx.settings.max_nodes: # 如果还没有达到最大节点数
                    if ctx.node not in state.wait_list: # 如果当前node不在等待列表之中
                        return _Action.ADD_TO_WAIT_LIST # 就加入到等待列表,发送一个等待action
            elif is_participant: # 如果已经在参与者列表
                # If the rendezvous has enough number of participants including us,
                # check whether we have passed the rendezvous deadline. If yes,
                # complete it.
                if len(state.participants) >= ctx.settings.min_nodes: # 如果达到了最小节点数
                    if cast(datetime, state.deadline) < datetime.utcnow(): # 如果达到了超时
                        return _Action.MARK_RENDEZVOUS_COMPLETE # 标示 rendezvous 已经结束
            else: # 否则就直接加入到参与者
                # The rendezvous is not complete yet and we are not part of it. Try
                # to join.
                return _Action.ADD_TO_PARTICIPANTS
    
            if _should_keep_alive(ctx): # 如果需要保持心跳,就返回 _Action.KEEP_ALIVE
                return _Action.KEEP_ALIVE
    
            # At this point either the rendezvous is not complete, but we are part
            # of it, which means we have to wait for other participants to join; or
            # the rendezvous is complete, but we are not part of it, which means we
            # have to wait for the next round.
            return _Action.SYNC # 否则返回同步状态 _Action.SYNC
    

    2.2.3 等待业务操作

    _DistributedRendezvousOpExecutor 之中,run 函数实现了基础逻辑,就是依据 action 类型进行各种操作。

        def run(
            self, state_handler: Callable[[_RendezvousContext, float], _Action], deadline: float
        ) -> None:
            """See base class."""
            action = None
    
            while action != _Action.FINISH: # 一直循环,直到结束
         
                # 这里很重要,在所有node之间做信息同步
                has_set = self._state_holder.sync() # 因为最新状态在 rendezvous。
                self._state = self._state_holder.state
    					  # 使用最新state构建ctx
                ctx = _RendezvousContext(self._node, self._state, self._settings)
    
                # Determine the next action to take based on the current state of
                # the rendezvous.
                action = state_handler(ctx, deadline) # 调用_RendezvousJoinOp,决定下一个操作,这里得到了 _Action.ADD_TO_WAIT_LIST
    
                if action == _Action.SYNC:
                    _delay(seconds=1)
                else:
                    if action == _Action.KEEP_ALIVE:
                        self._keep_alive()
                    elif action == _Action.ADD_TO_WAIT_LIST: # 从 Join 算子得到了_Action.ADD_TO_WAIT_LIST
                        self._add_to_wait_list() # 进行业务逻辑
                    # 省略其他action
    
                    # Attempt to sync our changes back to other nodes.
                    self._state_holder.mark_dirty() # 同步回其他节点
    

    具体处理等待操作就是加入到等待列表。

    def _add_to_wait_list(self) -> None:
        self._state.wait_list.add(self._node)
        self._keep_alive()
    

    我们回忆一下 _RendezvousState。_RendezvousState 是rendezvous的状态。是动态信息。

    • round:Rendezvous的当前轮次
    • complete:一个布尔值,指示rendezvous当前一轮是否完成了。
    • deadline:截止时间,如果如果当前轮次一直在等待节点加入,如果这个参数设置了,就是等待的截至时间。
    • closed:一个布尔值,指示rendezvous是否结束了。
    • participants:字典,存放参与者和它们对应ranks。
    • wait_list:set结构,存放等待参与下一轮rendezvous操作的一组节点
    • last_heartbeats:字典,包含每个节点上次心跳时间。
    class _RendezvousState:
        round: int
        complete: bool
        deadline: Optional[datetime]
        closed: bool
        participants: Dict[_NodeDesc, int] # 参与者,未来会用到的成员变量
        wait_list: Set[_NodeDesc]  # 等待者,这里用到的成员变量
        last_heartbeats: Dict[_NodeDesc, datetime]
    
        def __init__(self) -> None:
            self.round = 0
            self.complete = False
            self.deadline = None
            self.closed = False
            self.participants = {}
            self.wait_list = set() # 这里用到的成员变量
            self.last_heartbeats = {}
    

    目前逻辑如下:

    1. 启动一个新 worker。此时下图右侧上方的 _RendezvousState 之中,wait_list 为空。
    2. 调用 next_rendezvous,发起新一轮 rendezvous。
    3. _RendezvousJoinOp 内部运行,生成 ADD_TO_WAIT_LIST。
    4. executor . run 内部运行 _add_to_wait_list。
    5. 往 wait_list 添加一个新的 node。此时下图右侧上方的 _RendezvousState 之中,wait_list 多了一个 1。
      python -m torch.distributed.run             +-------------------------+     +
          --nnodes=xxx TRAINING_SCRIPT.py         | _RendezvousState        |     |
                     +                            |                         |     |
                     |                            |    participants = [1,2] |     |
                     | 1                          |                         |     |
                     v                            |    wait_list = []       |     |
              next_rendezvous                     |                         |     |
                     +                            +------------+------------+     |
                     | 2                                       |                  |
                     |                                         |                  |
                     v                                         |                  |
    +----------------+-----------------------+                 |                  |
    | _op_executor.run(_RendezvousJoinOp)    |                 |                  |
    |           +              +             |                 |                  |
    |           |              | 3           |                 |                  |
    |           |              |             |                 |                  |
    |           |              v             |                 |                  |
    |           |   _Action.ADD_TO_WAIT_LIST |                 v                  |
    |           |              +             |                                    |
    |           |              |             |    +--------------------------+    |
    |           +<-------------+             |    | _RendezvousState         |    |
    |           |                            |    |                          |    |
    |           |                            |    |    participants = [1,2]  |    |
    |           v       4                    | 5  |                          |    |
    |      self._add_to_wait_list() +----------------> wait_list = [3]       |    |
    |                                        |    |                          |    |
    +----------------------------------------+    +--------------------------+    |
                                                                                  |
                                                                                  v
    
                                                                             Timeline
    

    2.3 Agent 处理

    _DistributedRendezvousOpExecutor . run 处理之后,操作回到了代理之中。代理主循环之中,程序会进入 while 循环,然后通过 _monitor_workers 定期轮训用户程序运行情况,依据情况作出判断。

        def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
            # NOTE: currently only works for a single role
    
            spec = self._worker_group.spec
            role = spec.role
    
            self._initialize_workers(self._worker_group) # 启动worker
            monitor_interval = spec.monitor_interval
            rdzv_handler = spec.rdzv_handler
    
            while True:
                assert self._worker_group.state != WorkerState.INIT
                # 定期监控
                time.sleep(monitor_interval)
                # 监控客户程序运行情况
                run_result = self._monitor_workers(self._worker_group)
                state = run_result.state # 进程运行情况
                self._worker_group.state = state
    
                if state == WorkerState.SUCCEEDED:
                    # 程序正常结束
                    self._exit_barrier()
                    return run_result
                elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
                    # 程序出错
                    if self._remaining_restarts > 0: # 重试
                        self._remaining_restarts -= 1
                        self._restart_workers(self._worker_group)
                    else:
                        self._stop_workers(self._worker_group) # 重试次数达到,结束workers
                        self._worker_group.state = WorkerState.FAILED
                        self._exit_barrier()
                        return run_result
                elif state == WorkerState.HEALTHY:
    								# 程序正常运行
                    # 节点成员关系有变化,比如scale up
                    # membership changes do not count as retries
                    num_nodes_waiting = rdzv_handler.num_nodes_waiting()
                    group_rank = self._worker_group.group_rank
                    # 如果有新的节点在waiting,就重启所有workers
                    if num_nodes_waiting > 0:
                        self._restart_workers(self._worker_group)
                else:
                    raise Exception(f"[{role}] Worker group in {state.name} state")
    

    所以,代理定期运行 _monitor_workers 监控worker运行情况才是关键。run_result.state 是进程运行情况,当状态是 WorkerState.HEALTHY,说明原有程序正常运行,接下来看看节点成员关系是否有变化。

    调用 rdzv_handler.num_nodes_waiting() 拿到等待列表数目,如果有新的节点在waiting,就说明有新的节点试图加入集群,这时就会发生一个Re-rendezvous。代理将重启所有workers。重启时候,会把等待列表中的节点加入到参与列表之中。我们依次看看如何处理。

    2.3.1 检查等待列表

    处理时候,首先会调用 num_nodes_waiting 看看还有多少节点在等待,具体是看看 state.wait_list 的长度。我们通过之前 Join 操作知道,如果有新节点,会插入到这个列表之中。

    num_nodes_waiting 方法的作用是 返回在 rendezvous barrier 上等待的节点数目(这些节点不会在当前工作组被包括)。调用者应该周期调用这个方法,来确定是否有新节点等候加入当前工作组,因此需要调用next_rendezvous() 来提交他们。

    def num_nodes_waiting(self) -> int:
        """See base class."""
        with self._heartbeat_lock:
            self._state_holder.sync()
    
            return len(self._state_holder.state.wait_list)
    

    目前逻辑如下:

    1. 启动一个新 worker。
    2. 调用 next_rendezvous,发起新一轮 rendezvous。
    3. _RendezvousJoinOp 内部运行,生成 ADD_TO_WAIT_LIST。
    4. executor.run 内部运行 _add_to_wait_list。
    5. 往 wait_list 添加一个新的 node。
    6. Agent 之中,定期(比如 30S)运行一次 _monitor_workers,获取worker 子进程状态。
    7. 如果是 HEALTHY,则调用num_nodes_waiting 获取 wait_list 个数。
    8. 如果 wait_list 之中等待节点数目大于 0,则:
    9. 调用 _restart_workers 重启进程组。
      python -m torch.distributed.run             +-------------------------+     +
          --nnodes=xxx TRAINING_SCRIPT.py         | _RendezvousState        |     |
                     +                            |                         |     |
                     |                            |    participants = [1,2] |     |
                     | 1                          |                         |     |
                     v                            |    wait_list = []       |     |
              next_rendezvous                     |                         |     |
                     +                            +------------+------------+     |
                     | 2                                       |                  |
                     |                                         |                  |
                     v                                         |                  |
    +----------------+-----------------------+                 |                  |
    | _op_executor.run(_RendezvousJoinOp)    |                 |                  |
    |           +              +             |                 |                  |
    |           |              | 3           |                 |                  |
    |           |              |             |                 |                  |
    |           |              v             |                 |                  |
    |           |   _Action.ADD_TO_WAIT_LIST |                 v                  |
    |           |              +             |                                    |
    |           |              |             |    +--------------------------+    |
    |           +<-------------+             |    | _RendezvousState         |    |
    |           |                            |    |                          |    |
    |           |                            |    |    participants = [1,2]  |    |
    |           v       4                    | 5  |                          |    |
    |      self._add_to_wait_list() +----------------> wait_list = [3]       |    |
    |                                        |    |                          |    |
    +----------------------------------------+    +------------+-------------+    |
                                                               |                  |
    +----------------------------------------+                 |                  |
    | agent._invoke_run                      |                 |                  |
    |                                        |                 |                  |
    |                                        |                 |                  |
    |        _monitor_workers Every 30S      |                 |                  |
    |                +                       |                 |                  |
    |                | 6                     |                 |                  |
    |                |                       |                 v                  |
    |                v                       |                                    |
    |         WorkerState.HEALTHY            |     +--------------------------+   |
    |                +                       |     | _RendezvousState         |   |
    |                |                       |     |                          |   |
    |                | 7                     |     |     participants = [1,2] |   |
    |                v                       |  8  |                          |   |
    |        num_nodes_waiting   <-------------------->  wait_list = [3]      |   |
    |                +                       |     |                          |   |
    |                | 9                     |     |                          |   |
    |                |                       |     +--------------------------+   |
    |                v                       |                                    |
    |        _restart_workers                |                                    v
    |                                        |
    +----------------------------------------+                               Timeline
    

    2.3.3 重启worker组

    如果等待列表之中有节点,就会重启workers。我们走一下这个流程。

    @prof
    def _restart_workers(self, worker_group: WorkerGroup) -> None:
        """
        Restarts (stops, rendezvous, starts) all local workers in the group.
        """
    
        role = worker_group.spec.role
        self._stop_workers(worker_group)
        worker_group.state = WorkerState.STOPPED
        self._initialize_workers(worker_group)
    
    2.3.3.1 _stop_workers

    首先会停止目前 workers,代码在torch/distributed/elastic/agent/server/local_elastic_agent.py。

    @prof
    def _stop_workers(self, worker_group: WorkerGroup) -> None:
        self._shutdown()
    
    2.3.3.2 _shutdown

    _shutdown 就是让上下文关闭。

    def _shutdown(self) -> None:
        if self._pcontext:
            self._pcontext.close()
    
    2.3.3.3 关闭上下文

    在 MultiprocessContext 之中,close 方法是关闭所有子进程,然后等待其全部停止。

        def _close(self) -> None:
            if self._pc:
                for proc in self._pc.processes:
                    proc.terminate()
                    proc.join()
    
    2.3.3.4 _initialize_workers

    当关闭了所有当前运行的子进程之后,会重新全部初始化。

    @prof
    def _initialize_workers(self, worker_group: WorkerGroup) -> None:
        r"""
        Starts a fresh set of workers for the worker_group.
        Essentially a rendezvous followed by a start_workers.
    
        The caller should first call ``_stop_workers()`` to stop running workers
        prior to calling this method.
    
        Optimistically sets the state of the worker group that
        just started as ``HEALTHY`` and delegates the actual monitoring
        of state to ``_monitor_workers()`` method
        """
        role = worker_group.spec.role
    
        # TODO after stopping workers, wait at least monitor_interval*2 for
        # workers on different nodes to fail on a collective op before waiting
        # on the rdzv barrier, this way we ensure that nodes enter rdzv
        # at around the same time and reduce false positive rdzv timeout errors
        self._rendezvous(worker_group)
    
        worker_ids = self._start_workers(worker_group)
        for local_rank, w_id in worker_ids.items():
            worker = worker_group.workers[local_rank]
            worker.id = w_id
    
        worker_group.state = WorkerState.HEALTHY
    

    _rendezvous经过一系列操作,调用 next_rendezvous,在其中启动一个 ExitOp,一个 JoinOp 。

    def next_rendezvous(self) -> Tuple[Store, int, int]:
    
        exit_op = _RendezvousExitOp()
        join_op = _RendezvousJoinOp()
        
        self._op_executor.run(exit_op, deadline)
        self._op_executor.run(join_op, deadline)    
    
    2.3.3.5 _RendezvousJoinOp

    我们又回来了,这是新一轮 Rendezvous 操作。_DistributedRendezvousOpExecutor 之中,run 函数实现了基础逻辑,就是依据 action 类型进行各种操作。对于我们示例,state_handler 就是_RendezvousJoinOp

    def run(
        self, state_handler: Callable[[_RendezvousContext, float], _Action], deadline: float
    ) -> None:
        """See base class."""
        action = None
    
        while action != _Action.FINISH:
            # Reads or writes the latest rendezvous state shared by all nodes in
            # the rendezvous. Note that our local changes might get overridden
            # by another node if that node synced its changes before us.
            has_set = self._state_holder.sync()
            self._state = self._state_holder.state
            ctx = _RendezvousContext(self._node, self._state, self._settings)
    
            # Determine the next action to take based on the current state of
            # the rendezvous.
            # 调用到_RendezvousJoinOp,大家可以过一下 _RendezvousJoinOp 代码,发现此时将返回 ADD_TO_PARTICIPANTS
            action = state_handler(ctx, deadline) 
    
            if action == _Action.SYNC:
                # Delay the execution by one second to avoid overloading the
                # backend if we are asked to poll for state changes.
                _delay(seconds=1)
            else:
                if action == _Action.KEEP_ALIVE:
                    self._keep_alive()
                elif action == _Action.ADD_TO_PARTICIPANTS: # 运行到这里
                    self._add_to_participants()
                elif action == _Action.ADD_TO_WAIT_LIST:
                    self._add_to_wait_list()
                elif action == _Action.REMOVE_FROM_PARTICIPANTS:
                    self._remove_from_participants()
                elif action == _Action.REMOVE_FROM_WAIT_LIST:
                    self._remove_from_wait_list()
                elif action == _Action.MARK_RENDEZVOUS_COMPLETE:
                    self._mark_rendezvous_complete()
                elif action == _Action.MARK_RENDEZVOUS_CLOSED:
                    self._mark_rendezvous_closed()
    
                # Attempt to sync our changes back to other nodes.
                self._state_holder.mark_dirty()
    

    这次会生成 ADD_TO_PARTICIPANTS。

    class _RendezvousJoinOp:
        """Represents a rendezvous join operation."""
    
        def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
            state = ctx.state # 从上下文之中提取 _RendezvousState 状态
    
            # A closed rendezvous means that it no longer accepts new nodes.
            if state.closed:
                return _Action.ERROR_CLOSED # 如果已经结束,就返回 _Action.ERROR_CLOSED
    
            is_participant = ctx.node in state.participants # 看看是参与者
    
            # If we are part of the rendezvous and it is already complete there is
            # no further action to take.
            if state.complete and is_participant: # 如果是参与者且状态结束,就返回 _Action.FINISH
                return _Action.FINISH
    
            now = time.monotonic()
            if now > deadline: # 如果已经超时
                rollback_period = 5  # 5 seconds
    
                # If we still have time to rollback (a short period on top of the
                # operation deadline), try to remove ourself from the rendezvous.
                # It is okay if we can't though as our keep-alive will eventually
                # expire.
                if now <= deadline + rollback_period: # 如果还有时间来 rollback
                    # If we are part of the rendezvous, it means we couldn't find
                    # enough participants to complete it on time.
                    if is_participant: # 已经是参与者了
                        return _Action.REMOVE_FROM_PARTICIPANTS # 需要从参与者列表移除
                    # If we are in the wait list, it means we couldn't wait till the
                    # next round of the rendezvous.
                    if ctx.node in state.wait_list: # 已经在等待列表之中
                        return _Action.REMOVE_FROM_WAIT_LIST # 需要从等待列表移除
                return _Action.ERROR_TIMEOUT # 返回超时
    
            if state.complete: # 如果 rendezvous 已经结束
                # If we are here, it means we are not part of the rendezvous. In
                # case the rendezvous has capacity for additional participants add
                # ourself to the wait list for the next round.
                if len(state.participants) < ctx.settings.max_nodes: # 如果还没有达到最大节点数
                    if ctx.node not in state.wait_list: # 如果当前node不在等待列表之中
                        return _Action.ADD_TO_WAIT_LIST # 就加入到等待列表,发送一个等待action
            elif is_participant: # 如果已经在参与者列表
                # If the rendezvous has enough number of participants including us,
                # check whether we have passed the rendezvous deadline. If yes,
                # complete it.
                if len(state.participants) >= ctx.settings.min_nodes: # 如果达到了最小节点数
                    if cast(datetime, state.deadline) < datetime.utcnow(): # 如果达到了超时
                        return _Action.MARK_RENDEZVOUS_COMPLETE # 标示 rendezvous 已经结束
            else: # 否则就直接加入到参与者
                # The rendezvous is not complete yet and we are not part of it. Try
                # to join.
                return _Action.ADD_TO_PARTICIPANTS
    
            if _should_keep_alive(ctx): # 如果需要保持心跳,就返回 _Action.KEEP_ALIVE
                return _Action.KEEP_ALIVE
    
            # At this point either the rendezvous is not complete, but we are part
            # of it, which means we have to wait for other participants to join; or
            # the rendezvous is complete, but we are not part of it, which means we
            # have to wait for the next round.
            return _Action.SYNC # 否则返回同步状态 _Action.SYNC
    
    2.3.3.6 _add_to_participants

    引擎收到 ADD_TO_PARTICIPANTS 之后,会调用 _add_to_participants 从 wait_list 移除节点,插入到 participants。

    def _add_to_participants(self) -> None:
        log.debug(
            f"The node '{self._node}' added itself to the participants of round "
            f"{self._state.round} of the rendezvous '{self._settings.run_id}'. Pending sync."
        )
    
        state = self._state
        state.wait_list.remove(self._node) # 移除节点
    
        # The ranks of the participants will be set once the rendezvous is
        # complete.
        state.participants[self._node] = 0 # 重新插入
    
        self._keep_alive()
    
        if len(state.participants) == self._settings.min_nodes:
            state.deadline = datetime.utcnow() + self._settings.timeout.last_call
    
        if len(state.participants) == self._settings.max_nodes:
            self._mark_rendezvous_complete()
    

    我们这次从 _restart_workers 开始绘制。

    1. 调用 _stop_workers 来关闭worker子进程。此时下图右侧上方 _RendezvousState之中,participants=[1,2]。
    2. 通过 MultiprocessContext.close() 完成关闭操作。
    3. 通过 _initialize_workers 重新初始化 worker。
    4. 调用 next_rendezvous 完成新的同步操作。
    5. _RendezvousJoinOp 这次返回ADD_TO_PARTICIPANTS。
    6. 调用 _add_to_participants 进行状态切换。
    7. wait_list 之中的Node被移动到 participants。此时下图右侧上方 _RendezvousState之中,participants=[1,2,3]。
                             +-----------------------------+   +------------------------+  |
                             |  agent._invoke_run          |   | _RendezvousState       |  |
                             |                             |   |                        |  |
                             |       _restart_workers      |   |   participants = [1,2] |  |
                             |              +              |   |                        |  |
    +----------------------+ |              |              |   |   wait_list = [3]      |  |
    | MultiprocessContext  | |              | 1            |   |                        |  |
    |                      | | 2            v              |   +------------------------+  |
    |        close()  <-----------+  _stop_workers         |                               |
    |                      | |              +              |                               |
    +----------------------+ |              |              |                               |
                             |              | 3            |                               |
                             |              v              |                               |
                             |     _initialize_workers     |                               |
                             |              +              |                               |
                             |              |              |                               |
                             +-----------------------------+                               |
                                            |                                              |
                                            | 4                                            |
                                            v                                              |
                                     next_rendezvous                                       |
                                            +                                              |
                                            |                                              |
                                            v                                              |
                +---------------------------+---------------+                              |
                | _op_executor.run(_RendezvousJoinOp)       |                              |
                |           +               +               |                              |
                |           |               |               |                              |
                |           |               | 5             |                              |
                |           |               v               |                              |
                |           |       ADD_TO_PARTICIPANTS     |                              |
                |           |               +               |   +-----------------------+  |
                |           |               |               |   | _RendezvousState      |  |
                |           | <-------------+               |   |                       |  |
                |           |                               |   | participants = [1,2,3]|  |
                |           v     6                  7      |   |                       |  |
                |        _add_to_participants  +--------------> | wait_list = []        |  |
                |                                           |   |                       |  |
                +-------------------------------------------+   +-----------------------+  v
    
                                                                                     Timeline
    
    
    

    0x03 节点离开

    3.1 处理机制

    节点离开(scale-down)的处理如下:

    • 当Scale down事件发生时,rendezvous将不会通知 torchelastic agent。
    • 如果TE agent以“max_restarts=0”启动,它依赖于底层调度程序来处理作业重新启动。
    • 如果“max_restarts>0”,TE代理将终止workers并开始新一轮rendezvous。
      • 代理得到离开的通知,于是现有workers(所有节点上)都全部停止。
      • 这些workers将形成一个新的“WorkerGroup”,所有worker都将以新的RANKWORLD_SIZE 运行。、

    3.2 如何模拟

    如果想模拟调试的同学,可以在 test/distributed/elastic/agent/server/test/local_elastic_agent_test.py 之中找到示例代码。

    def test_double_agent_elastic(self):
        """
        start ``nnodes`` agents, kill odd ones (do not restart), validate
        elasticity (scale-down) works. (scale-up covered in fault_tolerance test)
        """
        min_nodes = 1
        max_nodes = 2
        wait = 2
        node_conf = Conf(entrypoint=_dist_sum, args=(wait,), local_world_size=2)
        agent_results = mp.Queue()
        agent_args = {
            "conf": node_conf,
            "agent_results": agent_results,
            "min_nodes": min_nodes,
            "max_nodes": max_nodes,
            "max_restarts": 2,
        }
    
        procs = []
        for _ in range(max_nodes):
            p = mp.Process(
                target=self.run_agent,
                kwargs=agent_args,
            )
            procs.append(p)
            p.start()
    
        # kill odd agents
        for i in range(max_nodes):
            if i % 2 != 0:
                procs[i].kill()
    
        for i in range(max_nodes):
            p = procs[i]
            p.join()
            if i % 2 == 0:
                self.assertEqual(0, p.exitcode)
            else:
                self.assertEqual(-signal.SIGKILL, p.exitcode)
    

    3.3 如何处理

    节点离开,与错误处理是同一个代码。错误处理代码如下,如果重试尚未达到最大次数,则试图重启workers。如果已经达到了最大次数,则停止 workers。

        def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
            
            # 省略
         
            while True:
    
                # 定期监控
                time.sleep(monitor_interval)
                # 监控客户程序运行情况
                run_result = self._monitor_workers(self._worker_group)
                
                elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
                # 程序出错
                
                if self._remaining_restarts > 0: # 重试
                    self._remaining_restarts -= 1
                    self._restart_workers(self._worker_group) # 进行重启
                else:
                    self._stop_workers(self._worker_group) # 重试次数达到,结束workers
                    self._worker_group.state = WorkerState.FAILED
                    self._exit_barrier()
                    return run_result
    

    3.3.1 重启

    _restart_workers 会停掉所有 workers,然后重新一轮 rendezvous 。

    @prof
    def _restart_workers(self, worker_group: WorkerGroup) -> None:
        """
        Restarts (stops, rendezvous, starts) all local workers in the group.
        """
    
        role = worker_group.spec.role
        self._stop_workers(worker_group)
        worker_group.state = WorkerState.STOPPED
        self._initialize_workers(worker_group)
    

    3.3.2 停止

    停止 workers 就是关闭上下文。

    def _shutdown(self) -> None:
        if self._pcontext:
            self._pcontext.close()
            
    @prof
    def _stop_workers(self, worker_group: WorkerGroup) -> None:
        self._shutdown()
    

    在 MultiprocessContext 之中,close 方法是关闭所有子进程,然后等待其全部停止。

        def _close(self) -> None:
            if self._pc:
                for proc in self._pc.processes:
                    proc.terminate()
                    proc.join()
    

    流程图如下:

    1. 监控子进程状态。
    2. 发现 UNHEALTHY 或者 FAILED,看看重启次数是否还有。我们假定是3号进程失败。
    3. 如果没有,就调用 _stop_workers 结束子进程。
    4. 调用 MultiprocessContext.close 进行具体结束操作。
    5. 如果还可以重启,调用_restart_workers。
    6. 调用 _stop_workers 结束子进程。
    7. 调用 MultiprocessContext.close 进行具体结束操作。
    8. 调用 _initialize_workers 重新初始化worker。
    9. 调用 next_rendezvous 重新同步。
    10. 进行后续操作。
                                                                                     +
    +-------------------------------------------+    +---------------------------+   |
    | agent._invoke_run                         |    | _RendezvousState          |   |
    |                                           |    |                           |   |
    |                                           |    |                           |   |
    |     _monitor_workers Every 30S            |    |    participants = [1,2,3] |   |
    |             +                             |    |                           |   |
    |             | 1                           |    |    wait_list = [ ]        |   |
    |             |                             |    |                           |   |
    |             v                             |    +---------------------------+   |
    |     WorkerState.UNHEALTHY,FAILED          |                                    |
    |             +                             |                                    |
    |             |                             |                                    |
    |             | 2                           |                                    |
    |             v                             |                                    |
    |   self._remaining_restarts > 0 ? +--+     |                                    |
    |             +                       |     |                                    |
    |          5  | YES                NO | 3   |                                    |
    |             |                       |     |                                    |
    |             v                       v     |    +----------------------+        |
    |     _restart_workers        _stop_workers |    | MultiprocessContext  |        |
    |             +                       +     |    |                      |        |
    |             | 6                     |  4  |    |                      |        |
    |             |                       +--------> |                      |        |
    |             v                             |    |        close()       |        |
    |      _stop_workers +-------------------------> |                      |        |
    |             +                 7           |    +----------------------+        |
    |             |                             |                                    |
    |             | 8                           |                                    |
    |             v                             |                                    |
    |    _initialize_workers                    |                                    |
    |             +                             |                                    |
    |             |                             |                                    |
    +-------------------------------------------+                                    |
                  | 9                                                                |
                  |                                                                  |
                  v                                +--------------------------+      |
            next_rendezvous                        | _RendezvousState         |      |
                  +                                |                          |      |
                  |               10               |     participants = [1,2] |      |
                  +---------------------------->   |                          |      |
                  |                                |     wait_list = [ ]      |      v
                  | 10                             +--------------------------+
                  v                                                             Timeline
    

    至此,弹性训练全部分析完毕,或者说PyTorch分布式分析就告一段落,我们下文会介绍其他框架/库的分布式实现,敬请期待。

    0xFF 参考

    [源码解析] PyTorch 分布式之弹性训练(1) --- 总体思路

    [源码解析] PyTorch 分布式之弹性训练(2)---启动&单节点流程

    [源码解析] PyTorch 分布式之弹性训练(3)---代理

    [源码解析] PyTorch 分布式之弹性训练(4)---Rendezvous 架构和逻辑

    [源码解析] PyTorch 分布式之弹性训练(5)---Rendezvous 引擎

  • 相关阅读:
    NOI2007项链工厂——sbTreap代码
    终于还是卡着进队了
    SCOI RP+=INF
    每日算法——新型在线LCA
    每日算法——并查集的应用
    每日算法--矩阵乘法优化递推
    神一般的数据结构--可持久化treap
    算法竞赛中的数论经典定理
    Baby Step Gaint Step
    素数分组 哥德巴赫猜想
  • 原文地址:https://www.cnblogs.com/rossiXYZ/p/15743246.html
Copyright © 2020-2023  润新知