PYTORCH并行训练。Author: Shen Li
DistributedDataParallel (DDP) 分布式训练在模型层面实现数据并行。利用 torch.distributed包来同步梯度、参数和缓存。并行性在进程内和进程间都是可用的。在进程中,DDP将输入模块复制到device_ids中指定的设备,相应地沿批处理维度分散输入,并将输出收集到output_device,类似于DataParallel.。在进程之间,DDP在前向过程中插入必要的参数同步,在反向过程中插入梯度同步 。只要进程不共享GPU设备,用户就可以将进程映射到可用资源。推荐的(通常是最快的)方法是为每个模块副本创建一个进程,即进程中没有模块复制。本教程中的代码运行在8-GPU服务器上,但它可以很容易地推广到其他环境中。
Comparison between DataParallel
and DistributedDataParallel
在学习之前先弄明白为啥尽管操作复杂,你还是会考虑利用DistributedDataParallel
而非DataParallel
:
- 首先回顾前文,如果模型太大塞不到一个GPU里面就必须用model parallel来划分为多个GPU上,而DistributedDataParallel在model parallel中是work的,这次DataParallel这次不然。
- DataParallel是单进程,多线程。仅仅在一台机器上工作,而
DistributedDataParallel
是多进程并可同时在单机或多机中训练,因此,即使对于单机训练当数据足够小到可以放到一台机器里,DistributedDataParallel
也应当比DataParallel快。DistributedDataParallel
也将模型复制,而不是每次迭代,并使全局解释器锁定。 - 如果你的数据太大放不到一台机器并且模型太大放不到一个GPU,则可以结合model parallel(将单个模型分解到不同GPU上)与DistributedDataParallel。这时每个
DistributedDataParallel
进程会利用model parallel,并且所有的进程都爱能够使用dara parallel。
Basic Use Case
为创建DDP module,首先建立进程组,更多细节可以在:Writing Distributed Applications with PyTorch.
import os import tempfile import torch import torch.distributed as dist import torch.nn as nn import torch.optim as optim import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP def setup(rank, world_size): os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_PORT'] = '12355' # initialize the process group dist.init_process_group("gloo", rank=rank, world_size=world_size) # Explicitly setting seed to make sure that models created in two processes # start from same random weights and biases. torch.manual_seed(42) def cleanup(): dist.destroy_process_group()
然后简单实验,利用DDP来处理,随即输入一些数据。注意的是如果模型由随机参数初始化,那么需要保证所有DDP进程有相同的初始值,不然 global gradient synchronizes也没意义。
class ToyModel(nn.Module): def __init__(self): super(ToyModel, self).__init__() self.net1 = nn.Linear(10, 10) self.relu = nn.ReLU() self.net2 = nn.Linear(10, 5) def forward(self, x): return self.net2(self.relu(self.net1(x))) def demo_basic(rank, world_size): setup(rank, world_size) # setup devices for this process, rank 1 uses GPUs [0, 1, 2, 3] and # rank 2 uses GPUs [4, 5, 6, 7]. n = torch.cuda.device_count() // world_size device_ids = list(range(rank * n, (rank + 1) * n)) # create model and move it to device_ids[0] model = ToyModel().to(device_ids[0]) # output_device defaults to device_ids[0] ddp_model = DDP(model, device_ids=device_ids) loss_fn = nn.MSELoss() optimizer = optim.SGD(ddp_model.parameters(), lr=0.001) optimizer.zero_grad() outputs = ddp_model(torch.randn(20, 10)) labels = torch.randn(20, 5).to(device_ids[0]) loss_fn(outputs, labels).backward() optimizer.step() cleanup() def run_demo(demo_fn, world_size): mp.spawn(demo_fn, args=(world_size,), nprocs=world_size, join=True)
DDP包装了较低级别的分布式通信细节,并提供了一个干净的API,就好像它是一个本地模型一样。对于基本的应用,DDP仅需要一些LoCs来建立进程组。当应用DDP到更复杂的场景时。有一些需要注意的注意事项。
Skewed Processing Speeds
在DDP中,构造函数、前向方法和输出的微分是分布式同步点。不同的进程将以相同的顺序到达同步点,并在大致相同的时间进入每个同步点。否则,快速进程可能会提前到达,并在等待散乱的进程时超时。因此,用户负责跨流程平衡工作负载分布。有时,由于网络延迟、资源争用、不可预测的工作负载高峰等原因,处理速度出现偏差是不可避免的。为了避免在这种情况下超时,确保你在调用 init_process_group时传递了足够大的timeout值。
Save and Load Checkpoints
普遍利用torch.savehetorch.load来保存检查点和载入检查点。 SAVING AND LOADING MODELS里面有更多细节。当利用DDP时,一次优化保存模型到仅仅一个进程,并载入到所有进程减少写开销。这是正确的,因为所有进程都是从相同的参数开始的,并且梯度在反向过程中是同步的,因此优化器应该将参数设置为相同的值。如果使用此优化,请确保在保存完成之前,所有进程都不会开始加载。此外当载入模型时。需要体哦国内合适的map_location参数来避免进程进入其他设备。如果该参数缺失,torch.load将首先载入模型到GPU并复制其他参数到保存的地方,将导致同一台计算机上的所有进程使用同一组device。
def demo_checkpoint(rank, world_size): setup(rank, world_size) # setup devices for this process, rank 1 uses GPUs [0, 1, 2, 3] and # rank 2 uses GPUs [4, 5, 6, 7]. n = torch.cuda.device_count() // world_size device_ids = list(range(rank * n, (rank + 1) * n)) model = ToyModel().to(device_ids[0]) # output_device defaults to device_ids[0] ddp_model = DDP(model, device_ids=device_ids) loss_fn = nn.MSELoss() optimizer = optim.SGD(ddp_model.parameters(), lr=0.001) CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint" if rank == 0: # All processes should see same parameters as they all start from same # random parameters and gradients are synchronized in backward passes. # Therefore, saving it in one process is sufficient. torch.save(ddp_model.state_dict(), CHECKPOINT_PATH) # Use a barrier() to make sure that process 1 loads the model after process # 0 saves it. dist.barrier() # configure map_location properly rank0_devices = [x - rank * len(device_ids) for x in device_ids] device_pairs = zip(rank0_devices, device_ids) map_location = {'cuda:%d' % x: 'cuda:%d' % y for x, y in device_pairs} ddp_model.load_state_dict( torch.load(CHECKPOINT_PATH, map_location=map_location)) optimizer.zero_grad() outputs = ddp_model(torch.randn(20, 10)) labels = torch.randn(20, 5).to(device_ids[0]) loss_fn = nn.MSELoss() loss_fn(outputs, labels).backward() optimizer.step() # Use a barrier() to make sure that all processes have finished reading the # checkpoint dist.barrier() if rank == 0: os.remove(CHECKPOINT_PATH) cleanup()
Combine DDP with Model Parallelism
DDP与multi-gpu模型可一起工作,但是在一个进程中复制是不支持的。你需要对每个模型拷贝创建一个进程,这通常比一个进程的多个拷贝要有更好的表现。DDP与multi-gpu模型一起用在大量数据且大模型训练是很有用的。利用这个特点,multi-gpu模型需要小心实现避免hard-coded devices,因为不同的模型复制后将置于不同的device。
class ToyMpModel(nn.Module): def __init__(self, dev0, dev1): super(ToyMpModel, self).__init__() self.dev0 = dev0 self.dev1 = dev1 self.net1 = torch.nn.Linear(10, 10).to(dev0) self.relu = torch.nn.ReLU() self.net2 = torch.nn.Linear(10, 5).to(dev1) def forward(self, x): x = x.to(self.dev0) x = self.relu(self.net1(x)) x = x.to(self.dev1) return self.net2(x)
当传递multi-gpu到DDP时,device_ids和output_device不应被set。输入和输出数据将在合适的设备上放置。
def demo_model_parallel(rank, world_size): setup(rank, world_size) # setup mp_model and devices for this process dev0 = rank * 2 dev1 = rank * 2 + 1 mp_model = ToyMpModel(dev0, dev1) ddp_mp_model = DDP(mp_model) loss_fn = nn.MSELoss() optimizer = optim.SGD(ddp_mp_model.parameters(), lr=0.001) optimizer.zero_grad() # outputs will be on dev1 outputs = ddp_mp_model(torch.randn(20, 10)) labels = torch.randn(20, 5).to(dev1) loss_fn(outputs, labels).backward() optimizer.step() cleanup() if __name__ == "__main__": run_demo(demo_basic, 2) run_demo(demo_checkpoint, 2) if torch.cuda.device_count() >= 8: run_demo(demo_model_parallel, 4)