• [源码解析] PyTorch 分布式之弹性训练(2)启动&单节点流程


    [源码解析] PyTorch 分布式之弹性训练(2)---启动&单节点流程

    0x00 摘要

    在前面的文章之中,我们已经学习了PyTorch 分布式的基本模块,介绍了官方的几个例子,我们接下来会介绍PyTorch的弹性训练,本文是第二篇,重点关注的是如何启动弹性训练,并且可以对系统总体架构有所了解。

    弹性训练系列文章如下:

    [源码解析] PyTorch 分布式之弹性训练(1) --- 总体思路

    0x01 重要概念

    为了更好的说明(这个说明可能在后面文章也会出现,因为太重要了),我们先总述一下TE 最重要的 Agent 和 Rendezvous 两个概念。

    • Agent :Agent是运行在单节点上的独立后台进程,可以认为是 worker manager 或者 process supervisor,其负责启动worker,监控 worker 运行,捕获woker异常,通过 rendezvous 实现 worker 间的相互发现(比如把状态上报到KVStore),成员变动时候基于 rendezvous 进行变更同步等等。
    • Rendezvous :为了实现弹性训练,需要有一个节点/进程之间彼此发现的机制。Rendezvous就是这个发现机制或者说同步组件。当系统启动或者成员变更时候,所有worker会(重新)集合(rendezvous)以建立一个新的进程组。

    我们从源码中取出示意图看看,大家先有一个总体概念。

    0x02 分布式运行

    2.1 方式改变

    2.1.1 原有方式

    我们知道,PET是从 PyTorch v1.9 合并进来的,因为合并了弹性训练,所以分布式启动的方式有了很大的改变。

    V1.9 之前是使用 torch/distributed/launch.py 进行启动,比如:

    python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE
               --nnodes=2 --node_rank=0 --master_addr="192.168.1.1"
               --master_port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3
               and all other arguments of your training script)
    

    此处参数含义是:

    • nnodes :是参与训练的节点数目。
    • nproc_per_node :每个节点上运行的进程数目。
    • node_rank :当前节点标识符。
    • master_addrmaster_port 是 master 监听的地址和端口。

    当运行时,torch.distributed.launch 会设置一些环境变量,包括 world_sizemaster_addrmaster_port 等等。然后在当前机器上创建 nproc_per_node 个进程,这些进程构成了一个本地组。如果一共有 NODE_SIZE 个机器参与训练,则一共有 NODE_SIZE * TRAINERS_PER_NODE 个进程。如果想启动一个分布式训练任务,则需要在所有的机器上执行相关命令。

    2.1.2 目前方式

    PyTorch 1.9 使用 torch/distributed/run.py 进行启动。如果依然采用 torch/distributed/launch.py,其实其内部已经透传给 run.py,具体参见代码:

    def main(args=None):
        logger.warn(
            "The module torch.distributed.launch is deprecated "
            "and going to be removed in future."
            "Migrate to torch.distributed.run"
        )
        args = parse_args(args)
        run(args)
    

    torch.distributed.run是之前torch.distributed.launch的一个超集,提供如下新功能:

    • 容错:通过重新启动所有workers,可以优雅地处理worker故障。
    • 自动:Worker 的RANKWORLD_SIZE 是自动分配的
    • 弹性:允许在最小值和最大值(弹性)之间更改节点数。

    为了使用弹性训练,用户代码也需要做一些修改,如果用户的训练脚本已经支持 torch.distributed.launch ,则只需要修改几处就可以使用torch.distributed.run

    • 无需手动传递RANK , WORLD_SIZE , MASTER_ADDR 和 MASTER_PORT。
    • 必须提供rdzv_backendrdzv_endpoint。对于大多数用户来说,这其实就是“c10d”(参见“rendezvous“)。其实这就替代了之前的MASTER_ADDR 和 MASTER_PORT。
    • use_env 参数已被删除。请从 LOCAL_RANK 环境变量中获取local_rank (例如,os.environ["LOCAL_RANK"])。
    • 用户需要确保脚本中有 load_checkpoint(path)save_checkpoint(path) 逻辑,即手动处理Checkpoint。因为当worker失败时,我们将使用最近的checkpoint来恢复现场,重启所有worker。

    下面是一个训练脚本的示例,该脚本在每个epoch上设置检查点,因此在失败时最差也只是会丢失一个epoch的训练成果。

      def main():
           args = parse_args(sys.argv[1:])
           state = load_checkpoint(args.checkpoint_path)
           initialize(state)
    
           # torch.distributed.run ensure that this will work
           # by exporting all the env vars needed to initialize the process group
           torch.distributed.init_process_group(backend=args.backend)
    
           for i in range(state.epoch, state.total_num_epochs)
                for batch in iter(state.dataset)
                    train(batch, state.model)
    
                state.epoch += 1
                save_checkpoint(state)
    

    所以,我们接下来看看在新模式之下,如何分布式启动。

    2.2 部署

    部署一般按照如下方式。

    1. (C10d后端不需要)启动 rendezvous 后端服务器,并获取端点(作为--rdzv_endpoint传递给启动程序脚本)
    2. 单节点多 worker:在主机上启动 launcher 以启动代理进程,代理会创建并监视本地工作组。
    3. 多节点多 worker:在所有节点上使用相同的参数启动 launcher 参加训练。

    当使用作业/群集管理器时,多节点作业的入口点命令应为 launcher。

    2.3 示例

    我们首先通过几个例子来看看如何启动分布式训练。

    2.3.1 单节点多worker启动

    单节点多worker的启动方式如下,其实就是Standalone 模式,这是分布式模式的一种特例,具体就是针对单机多 Worker 提供了一些便利设置。

    python -m torch.distributed.run
            --standalone
            --nnodes=1
            --nproc_per_node=$NUM_TRAINERS
            YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
    

    2.3.2 容错方式启动

    如下是容错方式启动,固定数目workers,没有弹性训练。 --nproc_per_node=$NUM_TRAINERS 一般是 单节点上GPU 个数。

    python -m torch.distributed.run
            --nnodes=$NUM_NODES
            --nproc_per_node=$NUM_TRAINERS
            --rdzv_id=$JOB_ID
            --rdzv_backend=c10d
            --rdzv_endpoint=$HOST_NODE_ADDR
            YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
    

    HOST_NODE_ADDR, 的格式是: [:] ,指定了 C10d rendezvous 后端所运行的节点地址和端口,这个节点可以是训练集群中任意节点,但是最好找一个高带宽的节点。

    2.3.3 弹性方式启动

    下面是弹性训练,弹性区间为 (min=1, max=4)。通过指定rdzv参数,可以实现多机训练,具备容错与弹性能力

    在多台机器上分别执行以下命令启动:最小节点数为MIN_SIZE,最大为MAX_SIZE,利用etcd服务实现一致性和信息同步。

    python -m torch.distributed.run
            --nnodes=1:4
            --nproc_per_node=$NUM_TRAINERS
            --rdzv_id=$JOB_ID
            --rdzv_backend=c10d
            --rdzv_endpoint=$HOST_NODE_ADDR
            YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
    

    HOST_NODE_ADDR, 的格式是: [:] ,指定了 C10d rendezvous 后端所运行的节点地址和端口,这个节点可以是训练集群中任意节点,但是最好找一个高带宽的节点。

    关于 rendezvous backend,有几点说明:

    对于多节点训练,需要指定:

    • --rdzv_id: 一个唯一的 job id,在参与job的所有节点之间共享。
    • --rdzv_backend: torch.distributed.elastic.rendezvous.RendezvousHandler 的一个实现。 (--rdzv_backend默认是static模式,不支持容错和弹性伸缩)
    • --rdzv_endpoint: rendezvous backend 所运行的 endpoint,通常格式为:host:port。就是取代了之前的 master address / port 设置。

    目前,以下几种后端可以直接使用,c10d (推荐), etcd-v2, and etcd (legacy) 。为了使用 etcd-v2 或者 etcd,需要搭建一个 v2 api开启的 etcd server (即. --enable-v2)。

    0x03 启动脚本

    既然以上启动都是用 torch/distributed/run.py,所以我们仔细分析一下这个脚本,该脚本提供三个功能:

    • 依靠"重启所有 workers"来处理 worker 失败;

    • 自动分配 worker 的RANK and WORLD_SIZE

    • 弹性训练,即 node 数目允许在minimum和maximum之间改变;

    3.1 参数定义

    启动脚本中,一些参数定义如下:

    • Node - 物理实例或容器;映射到与 job manager 所协调的单元。
    • Worker - 分布式训练环境中的worker。
    • WorkerGroup - 执行相同功能的一组worker(例如trainers)。
    • LocalWorkerGroup - 在同一节点上运行的工作组中的workers子集。
      • 一个节点运行 LOCAL_WORLD_SIZE个workers,这些 workers 组成LocalWorkerGroup
      • 节点上所有LocalWorkerGroups组成WorkerGroups
    • RANK - 工作组中worker的rank,是全局rank,可以认为是一个全局GPU资源列表。
      • Rank是不稳定的,在重启之间,本地Workers 会被分配到不同的ranks,所以不要在代码中对RANKLOCAL_RANK的稳定性做任何假设和依赖编码。
      • rendezvous完成后,其所有成员将对工作成员资格以及每个人在其中的角色(role)达成共识。此角色(role)使用一个介于 0 ~ world size 之间的整型来表示,被称之为rank。
    • LOCAL_RANK - 本地工作组中,某个worker 的 rank,可以认为是当前节点上的GPU资源列表。
    • GROUP_RANK - worker group的rank。介于0和“最大节点数”之间的数字。如果每个节点运行一个单一工作组,那GROUP_RANK就是这个节点的rank。
    • ROLE_RANK - 对于具有相同角色worker来说,他们之间共享的rank,角色在“WorkerSpec”中被指定。
    • WORLD_SIZE - 工作组中worker的总数。因为节点会加入/离开,所以WORLD_SIZE会变化,不能依赖 WORLD_SIZE的稳定性进行编码。
    • LOCAL_WORLD_SIZE - 本地工作组的大小,即本地运行的worker数目,等于在torch.distributed.run运行时候指定的--nproc_per_node。目前,torch/distributed/run.py 仅支持同构的 LOCAL_WORLD_SIZE。也就是说,假设所有节点运行相同数量的本地工作者(每个角色)。
    • ROLE_WORLD_SIZE - 具有同样角色的workers总数,在 WorkerSpec之中被指定。
    • rdzv_id - 用户定义的id,用于唯一标识作业的工作组。这个id在每个节点加入特定工作组时候使用。
    • rdzv_backend-rendezvous 的后端(例如“c10d”)。这通常是一个强一致性的键值存储。
    • rdzv_endpoint - rendezvous 后端端点;通常以“<host>:<port>”的形式出现。
    • run_id: 用户定义的id,它唯一地标识分布式应用程序的一个实例。它通常映射到作业id并用于允许节点加入正确的分布式应用程序。
    • TORCHELASTIC_RUN_ID - 与 rendezvous run_id 相等,即唯一的job id。
    • TORCHELASTIC_RESTART_COUNT - 迄今为止,工作组重启的次数。
    • TORCHELASTIC_MAX_RESTARTS - 配置的最大重启数目。

    3.2 相关函数/变量

    为了更好的理解上面的参数,我们选取部分相关函数/变量看看。

    world_size,rank

    这两个变量是动态生成的,所以从 state 之中取出。

    rank, world_size = self._get_world()
        
    def _get_world(self) -> Tuple[int, int]:
    	state = self._state_holder.state
    	return state.participants[self._this_node], len(state.participants)
    

    _pg_group_ranks

    该全局变量存储了每个 group 的 global rank 到 local rank 映射信息。

    # Process group's global rank to local rank mapping
    _pg_group_ranks: Dict[ProcessGroup, Dict[int, int]] = {}
    

    其赋值举例如下:

    # Create the global rank to group rank mapping
    _pg_group_ranks[pg] = {
        global_rank: group_rank
        for group_rank, global_rank in enumerate(ranks)
    }
    

    group_rank

    我们可以利用 global rank 从 _pg_group_ranks 之中提取对应的 local rank。

    def _get_group_rank(group: ProcessGroup, rank):
        """
        Helper that gets a given group's local rank in the group from a given global
        rank.
        """
        if group is GroupMember.WORLD:
            raise RuntimeError("group.WORLD does not have local rank to global "
                               "rank mapping")
        if group not in _pg_group_ranks:
            raise RuntimeError("The given group does not exist")
        try:
            group_rank = _pg_group_ranks[group][rank]
        except KeyError:
            raise RuntimeError(f"The global rank {rank} is not part of the group {group}") from None
        return group_rank
    

    global_rank

    我们可以利用一个 group 的 local rank 获取到其 gloabl rank。

    def _get_global_rank(group, group_rank):
        """
        Helper that gets a given group's global rank from a given local rank in the
        group.
        """
        if group is GroupMember.WORLD:
            raise RuntimeError("group.WORLD does not have local rank to global "
                               "rank mapping")
        group_rank_map = _pg_group_ranks[group]
        for rank, grp_rank in group_rank_map.items():
            if grp_rank == group_rank:
                return rank
        raise RuntimeError("The group rank is not part of the group")
    

    group_size

    我们可以 _get_group_size 获取到某一个group 的大小。

    def _get_group_size(group):
        """
        Helper that gets a given group's world size.
        """
        if group is GroupMember.WORLD or group is None:
            default_pg = _get_default_group()
            return default_pg.size()
        if group not in _pg_group_ranks:
            raise RuntimeError("The given group does not exist")
        return len(_pg_group_ranks[group])
    

    nproc_per_node

    这个变量可以得到每个node之上支持多少个进程。

    def determine_local_world_size(nproc_per_node: str):
        try:
            logging.info(f"Using nproc_per_node={nproc_per_node}.")
            return int(nproc_per_node)
        except ValueError:
            if nproc_per_node == "cpu":
                num_proc = os.cpu_count()
                device_type = "cpu"
            elif nproc_per_node == "gpu":
                if not torch.cuda.is_available():
                    raise ValueError("Cuda is not available.")
                device_type = "gpu"
                num_proc = torch.cuda.device_count()
            elif nproc_per_node == "auto":
                if torch.cuda.is_available():
                    num_proc = torch.cuda.device_count()
                    device_type = "gpu"
                else:
                    num_proc = os.cpu_count()
                    device_type = "cpu"
            else:
                raise ValueError(f"Unsupported nproc_per_node value: {nproc_per_node}")
            )
            return num_proc
    

    3.3 脚本入口

    脚本入口主要代码如下,可以看到,其调用到了 elastic_launch 来完成功能,所以我们下一节就要顺藤摸瓜来看看这个函数。

    from torch.distributed.launcher.api import LaunchConfig, elastic_launch
    
    def run(args):
        if args.standalone: # 有两种模式:Standalone 模式和分布式模式,这里要判断一下
            args.rdzv_backend = "c10d"
            args.rdzv_endpoint = "localhost:29400"
            args.rdzv_id = str(uuid.uuid4())
            log.info(
                f"\n**************************************\n"
                f"Rendezvous info:\n"
                f"--rdzv_backend={args.rdzv_backend} "
                f"--rdzv_endpoint={args.rdzv_endpoint} "
                f"--rdzv_id={args.rdzv_id}\n"
                f"**************************************\n"
            )
    
        config, cmd, cmd_args = config_from_args(args)
        elastic_launch(
            config=config,
            entrypoint=cmd,
        )(*cmd_args)
    
    
    def main(args=None):
        args = parse_args(args)
        run(args)
    
    
    if __name__ == "__main__":
        logging.basicConfig(
            level=logging.INFO, format="[%(levelname)s] %(asctime)s %(module)s: %(message)s"
        )
        main()
    

    0x04 单体总体流程

    我们下面就从 elastic_launch 开始,看看在单节点上如何启动运行。我们首先给出一个总体示意图,图上是两个节点,每个节点有一个 agent,agent下面是一个 worker group,组下面是4个worker。

    4.1 小例子

    我们再从源码中找一个例子来看看,这里只是设置了两个workers。

    import uuid
    import torch
    from torch.distributed.launcher.api import LaunchConfig, elastic_launch
    
    def worker_fn(t1, t2):
        return torch.add(t1, t2)
    
    def main():
        t1 = torch.rand((3,3), requires_grad=True)
        t2 = torch.rand((3, 3), requires_grad=True)
    
        config = LaunchConfig(
            min_nodes=2,
            max_nodes=4,
            nproc_per_node=1,
            run_id=str(uuid.uuid4()),
            role="trainer",
            rdzv_endpoint="localhost:29400",
            rdzv_backend="c10d",
            max_restarts=1,
            monitor_interval=1,
            start_method="spawn",
        )
    
        outputs = elastic_launch(config, worker_fn)(t1, t2)
    
    if __name__ == '__main__':
        main()
    
    

    输出如下,可以看到有两个 worker 进程 和一个 agent 进程。

    {"name": "torchelastic.worker.status.SUCCEEDED", "source": "WORKER", "timestamp": 0, "metadata": {"run_id": "7fbf85fe-b8b3-462e-887e-8121e3062e0b", "global_rank": 0, "group_rank": 0, "worker_id": "12172", "role": "trainer", "hostname": "DESKTOP-0GO3RPO", "state": "SUCCEEDED", "total_run_time": 31, "rdzv_backend": "c10d", "raw_error": null, "metadata": "{\"group_world_size\": 1, \"entry_point\": \"worker_fn\", \"local_rank\": [0], \"role_rank\": [0], \"role_world_size\": [2]}", "agent_restarts": 0}}
    
    {"name": "torchelastic.worker.status.SUCCEEDED", "source": "WORKER", "timestamp": 0, "metadata": {"run_id": "7fbf85fe-b8b3-462e-887e-8121e3062e0b", "global_rank": 1, "group_rank": 0, "worker_id": "3276", "role": "trainer", "hostname": "DESKTOP-0GO3RPO", "state": "SUCCEEDED", "total_run_time": 31, "rdzv_backend": "c10d", "raw_error": null, "metadata": "{\"group_world_size\": 1, \"entry_point\": \"worker_fn\", \"local_rank\": [1], \"role_rank\": [1], \"role_world_size\": [2]}", "agent_restarts": 0}}
    
    {"name": "torchelastic.worker.status.SUCCEEDED", "source": "AGENT", "timestamp": 0, "metadata": {"run_id": "7fbf85fe-b8b3-462e-887e-8121e3062e0b", "global_rank": null, "group_rank": 0, "worker_id": null, "role": "trainer", "hostname": "DESKTOP-0GO3RPO", "state": "SUCCEEDED", "total_run_time": 31, "rdzv_backend": "c10d", "raw_error": null, "metadata": "{\"group_world_size\": 1, \"entry_point\": \"worker_fn\"}", "agent_restarts": 0}}
    

    4.2 入口

    顺着代码我们深入挖掘一下。elastic_launch 的作用就是启动一个 torchelastic agent,然后通过这个 agent来调用用户程序入口,agent 会启动 worker 进行训练,并且管理 worker 生命周期

    class elastic_launch:
        """
        Launches an torchelastic agent on the container that invoked the entrypoint.
    
            1. Pass the ``entrypoint`` arguments as non ``kwargs`` (e.g. no named parameters)/
               ``entrypoint`` can be a function or a command.
            2. The return value is a map of each worker's output mapped
               by their respective global rank.
        """
    
        def __init__(
            self,
            config: LaunchConfig,
            entrypoint: Union[Callable, str, None],
        ):
            self._config = config
            self._entrypoint = entrypoint
    
        def __call__(self, *args, **kwargs):
            return launch_agent(self._config, self._entrypoint, list(args)) # 内部会调用用户程序
    

    4.3 启动代理

    launch_agent 启动了一个 LocalElasticAgent,调用了其 run 方法。

    @record
    def launch_agent(
        config: LaunchConfig,
        entrypoint: Union[Callable, str, None],
        args: List[Any],
    ) -> Dict[int, Any]:
        if not config.run_id:
            run_id = str(uuid.uuid4().int)
            config.run_id = run_id
    
        entrypoint_name = _get_entrypoint_name(entrypoint, args)
    
        rdzv_parameters = RendezvousParameters(
            backend=config.rdzv_backend,
            endpoint=config.rdzv_endpoint,
            run_id=config.run_id,
            min_nodes=config.min_nodes,
            max_nodes=config.max_nodes,
            **config.rdzv_configs,
        )
    
        agent = None
        rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_parameters)
        master_addr, master_port = _get_addr_and_port(rdzv_parameters)
        try:
            spec = WorkerSpec( # 1. 得到spec
                role=config.role,
                local_world_size=config.nproc_per_node,
                entrypoint=entrypoint,
                args=tuple(args),
                rdzv_handler=rdzv_handler, # RendezvousHandler
                max_restarts=config.max_restarts,
                monitor_interval=config.monitor_interval,
                redirects=config.redirects,
                tee=config.tee,
                master_addr=master_addr,
                master_port=master_port,
            )
    
            cfg = metrics.MetricsConfig(config.metrics_cfg) if config.metrics_cfg else None
            metrics.initialize_metrics(cfg)
    
            agent = LocalElasticAgent( # 2. 构建代理
                spec=spec, start_method=config.start_method, log_dir=config.log_dir
            )
    
            result = agent.run() # 3. 启动代理
            events.record(agent.get_agent_status_event(WorkerState.SUCCEEDED))
            if result.is_failed():
                # ChildFailedError is treated specially by @record
                # if the error files for the failed children exist
                # @record will copy the first error (root cause)
                # to the error file of the launcher process.
                raise ChildFailedError(
                    name=entrypoint_name,
                    failures=result.failures,
                )
            else:
                return result.return_values
        except ChildFailedError:
            raise
        except Exception:
            if agent:
                events.record(agent.get_agent_status_event(WorkerState.FAILED))
            else:
                events.record(_construct_event(config))
            raise
        finally:
            rdzv_handler.shutdown()
    

    这里有几个关键点:

    4.3.1 WorkerSpec

    WorkerSpec :这是配置信息,里面包含了代理所需要的某些全局信息,比如 RendezvousHandler,role,entry(用户函数)。

    spec = {WorkerSpec} 
       args = {tuple: 2} (tensor, tensor)
       fn = {NoneType} None
       local_world_size = {int} 1
       master_addr = {NoneType} None
       master_port = {NoneType} None
       max_restarts = {int} 1
       monitor_interval = {int} 1
       rdzv_handler = {DynamicRendezvousHandler}
       redirects = {Std} Std.NONE
       role = {str} 'trainer'
       tee = {Std} Std.NONE
       entry = worker_fn
    

    代理会从这里提取各种所需信息。比如_start_workers 会从中获取 store。

    use_agent_store = spec.rdzv_handler.get_backend() == "static"
    

    此时逻辑为:

    +--------------------------+      +---------------------------------------------------+
    |LocalElasticAgent         |      | WorkerSpec                                        |
    |                          |      |                                                   |
    |     WorkerSpec +--------------> |      rdzv_handler = {DynamicRendezvousHandler} --------+
    |                          |      |                                                   |    |
    |     rdzv_run_id          |      |      entry = worker_fn                            |    |
    |                          |      |                                                   |    |
    |     store                |      |      role = {str} 'trainer'                       |    |
    |                          |      |                                                   |    |
    |                          |      +---------------------------------------------------+    |
    |                          |                                                               |
    |                          |                                                               |
    |                          |                                                               |
    |                          |                                                               |
    |                          |               +-----------------------------------------+     |
    +--------------------------+               |DynamicRendezvousHandler                 |     |
                                               |                                         |     |
                                               |                                         |     |
                                               |   _settings: RendezvousSettings         | <---+
                                               |                                         |
                                               |   _store: Store                         |
                                               |                                         |
                                               |   _state_holder: _RendezvousStateHolder |
                                               |                                         |
                                               |   _op_executor: _RendezvousOpExecutor   |
                                               |                                         |
                                               +-----------------------------------------+
    

    4.3.2 WorkerGroup

    WorkerGroup 代表了一个工作组。WorkerGroup 作为一个整体来管理多个 workers,进行批量处理。

    class WorkerGroup:
        """
        Represents the set of ``Worker`` instances for the given ``WorkerSpec``
        managed by ``ElasticAgent``. Whether the worker group contains cross
        instance workers or not depends on the implementation of the agent.
        """
    
        __slots__ = ["spec", "workers", "store", "group_rank", "group_world_size", "state"]
    
        def __init__(self, spec: WorkerSpec):
            self.spec = spec
            self.workers = [Worker(local_rank=i) for i in range(self.spec.local_world_size)]
    
            # assigned after rdzv
            self.store = None
            self.group_rank = None
            self.group_world_size = None
    
            self.state = WorkerState.INIT
    

    在SimpleElasticAgent 初始化之中,会建立一个 WorkerGroup。

    class SimpleElasticAgent(ElasticAgent):
        """
        An ``ElasticAgent`` that manages workers (``WorkerGroup``)
        for a single ``WorkerSpec`` (e.g. one particular type of worker role).
        """
    
        def __init__(self, spec: WorkerSpec, exit_barrier_timeout: float = 300):
            self._worker_group = WorkerGroup(spec)
            self._remaining_restarts = self._worker_group.spec.max_restarts
            self._store = None
            self._exit_barrier_timeout = exit_barrier_timeout
            self._total_execution_time = 0
    

    具体如下:

    +-----------------------------+      +------------------------------------------------+
    | LocalElasticAgent           |      | WorkerSpec                                     |
    |                             |      |                                                |
    | +------------------------+  |      |   rdzv_handler = {DynamicRendezvousHandler} -------+
    | |WorkerGroup             |  |      |                                                |   |
    | |            spec +--------------> |   entry = worker_fn                            |   |
    | |            workers     |  |      |                                                |   |
    | |            store       |  |      |   role = {str} 'trainer'                       |   |
    | |            group_rank  |  |      |                                                |   |
    | |       group_world_size |  |      +------------------------------------------------+   |
    | |                        |  |                                                           |
    | +------------------------+  |                                                           |
    |                             |                                                           |
    | rdzv_run_id                 |                                                           |
    | store                       |            +-----------------------------------------+    |
    |                             |            |DynamicRendezvousHandler                 |    |
    +-----------------------------+            |                                         |    |
                                               |                                         |    |
                                               |   _settings: RendezvousSettings         | <--+
                                               |                                         |
                                               |   _store: Store                         |
                                               |                                         |
                                               |   _state_holder: _RendezvousStateHolder |
                                               |                                         |
                                               |   _op_executor: _RendezvousOpExecutor   |
                                               |                                         |
                                               +-----------------------------------------+
    

    4.4 代理运行

    SimpleElasticAgent 是 LocalElasticAgent 的基类,所以会先运行到WorkerSpec.run 方法这里,run方法则调用了 _invoke_run。

        @prof
        def run(self, role: str = DEFAULT_ROLE) -> RunResult:
            start_time = time.monotonic()
            try:
                result = self._invoke_run(role) # 调用
                self._total_execution_time = int(time.monotonic() - start_time)
                self._record_metrics(result)
                self._record_worker_events(result)
                return result
            finally:
                # record the execution time in case there were any exceptions during run.
                self._total_execution_time = int(time.monotonic() - start_time)
                self._shutdown()
    

    4.5 代理主循环

    代理在 invoke_run 之中做如下操作:

    • 启动 _initialize_workers,这里会使用 _rendezvous 构建一个 rendezvous,然后调用 _start_workers 启动 workers。
    • 进入 while True 循环,在循环之中:
      • 通过 _monitor_workers 定期轮训用户程序运行情况,得到客户进程运行结果,然后依据情况作出判断。
        • 如果程序正常结束,则返回。
        • 如果程序出错,则重试,即重启所有 workers,如果重试次数达到依然有问题,就结束所有workers。
        • 如果节点成员关系有变化,比如scale up就会有新的节点在waiting,这时候就重启所有workers。
        def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
            # NOTE: currently only works for a single role
    
            spec = self._worker_group.spec
            role = spec.role
    
            self._initialize_workers(self._worker_group) # 启动worker
            monitor_interval = spec.monitor_interval
            rdzv_handler = spec.rdzv_handler
    
            while True:
                assert self._worker_group.state != WorkerState.INIT
                # 定期监控
                time.sleep(monitor_interval)
                # 监控客户程序运行情况
                run_result = self._monitor_workers(self._worker_group) # 得到进程运行结果
                state = run_result.state
                self._worker_group.state = state
    
                put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts)
                put_metric(f"workers.{role}.{state.name.lower()}", 1)
    
                if state == WorkerState.SUCCEEDED:
                    # 程序正常结束
                    self._exit_barrier()
                    return run_result
                elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
                    # 程序出错
                    if self._remaining_restarts > 0: # 重试
                        self._remaining_restarts -= 1
                        self._restart_workers(self._worker_group)
                    else:
                        self._stop_workers(self._worker_group) # 重试次数达到,结束workers
                        self._worker_group.state = WorkerState.FAILED
                        self._exit_barrier()
                        return run_result
                elif state == WorkerState.HEALTHY:
                    # 节点成员关系有变化,比如scale up,就会有新节点waiting
                    # membership changes do not count as retries
                    num_nodes_waiting = rdzv_handler.num_nodes_waiting()
                    group_rank = self._worker_group.group_rank
                    # 如果有新的节点在waiting,就重启所有workers
                    if num_nodes_waiting > 0:
                        self._restart_workers(self._worker_group)
                else:
                    raise Exception(f"[{role}] Worker group in {state.name} state")
    
    

    于是最终逻辑如下:

    +----------------------------------------------+
    | LocalElasticAgent                            |
    |                                              |    +---------------------------------------------------+
    |  rdzv_run_id                                 |    | WorkerSpec                                        |
    |                                              |    |                                                   |
    |  store           +------------------------+  |    |      rdzv_handler = {DynamicRendezvousHandler} +-------+
    |                  |WorkerGroup             |  |    |                                                   |    |
    |  _pcontext       |            spec +------------> |      entry = worker_fn                            |    |
    |                  |            workers     |  |    |                                                   |    |
    |                  |            store       |  |    |      role = {str} 'trainer'                       |    |
    |                  |            group_rank  |  |    |                                                   |    |
    |                  |       group_world_size |  |    +---------------------------------------------------+    |
    |                  |                        |  |                                                             |
    |                  +------------------------+  |                                                             |
    |  +----------------------------------------+  |                                                             |
    |  | _invoke_run                            |  |                                                             |
    |  |                                        |  |             +-----------------------------------------+     |
    |  |   _initialize_workers +------------------------+        |DynamicRendezvousHandler                 |     |
    |  |                                        |  |    |        |                                         |     |
    |  |                                        |  |    |        |                                         |     |
    |  |   while True:                          |  |    |        |   _settings: RendezvousSettings         | <---+
    |  |       _monitor_workers(_worker_group)  |  |    |        |                                         |
    |  |                +                       |  |    |        |   _store: Store                         |
    |  |                | _pcontext.wait        |  |    |        |                                         |
    |  |                |                       |  |    |        |   _state_holder: _RendezvousStateHolder |
    |  +----------------------------------------+  |    |        |                                         |
    |                   |                          |    |        |   _op_executor: _RendezvousOpExecutor   |
    +----------------------------------------------+    |        |                                         |
                        |                               |        +-----------------------------------------+
                        |                               |
                        v                               v
             +-------------------------------------------------+
             |  +------------+  +------------+  +------------+ |
             |  |Process     |  |Process     |  |Process     | |
             |  |            |  |            |  |            | |
             |  |    work_fn |  |   work_fn  |  |    work_fn | |
             |  |            |  |            |  |            | |
             |  +------------+  +------------+  +------------+ |
             +-------------------------------------------------+
    
    

    手机如下:

    至此,脚本如何启动和单体流程我们分析完毕,下一篇我们来具体分析代理。

    0xFF 参考

    [PyTorch Elastic源码阅读](

  • 相关阅读:
    python项目_mysql开启事务
    python项目_ImageField字段
    linux基础_常用命令
    mysql数据_查询操作
    list 和 tuple——python基础学习
    python-格式化
    python-字符串
    数学——变上限积分的应用
    python-交互模式
    蓝桥杯——汉诺塔问题
  • 原文地址:https://www.cnblogs.com/rossiXYZ/p/15725911.html
Copyright © 2020-2023  润新知