A simple example.
Note:

os.environ['MASTER_ADDR'] = 'xxx.xx.xx.xxx'  # fill in this machine's IP address
os.environ['MASTER_PORT'] = '29555'          # a free port

These two variables apparently must be set before init_process_group is called, because the chosen initialization method, init_method="env://" (the default environment-variable method), reads the rendezvous address and port from them.
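If you would rather not set environment variables, the rendezvous can also be passed directly through the init_method URL. The snippet below is a hypothetical variant of the example's init call, reusing the same address and port placeholders:

# alternative to env://: encode the master address directly in the init_method URL
dist.init_process_group(
    backend="gloo",
    init_method="tcp://xxx.xx.xx.xxx:29555",  # placeholder master IP and free port, as above
    rank=local_rank,
    world_size=world_size,
)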
# Single-machine multi-GPU training example
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

# https://pytorch.org/docs/stable/notes/ddp.html


def example(local_rank, world_size):
    # local_rank is supplied automatically by mp.spawn (0 .. world_size - 1)
    # create the default process group; nccl is usually preferred for GPU training
    dist.init_process_group(backend="gloo", init_method="env://",
                            rank=local_rank, world_size=world_size)
    # create a local model on this process's GPU
    model = nn.Linear(10, 10).cuda(local_rank)
    # construct the DDP wrapper around it
    ddp_model = DDP(model, device_ids=[local_rank], output_device=local_rank)
    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    for i in range(100):
        if local_rank == 0:
            # only rank 0 prints; without the guard every spawned process would print i
            print(i)
        # forward pass
        outputs = ddp_model(torch.randn(20, 10).cuda(local_rank))
        labels = torch.randn(20, 10).cuda(local_rank)
        # clear accumulated gradients, then backward pass
        # (DDP averages the gradients across processes)
        optimizer.zero_grad()
        loss_fn(outputs, labels).backward()
        # update parameters
        optimizer.step()

    dist.destroy_process_group()


def main():
    os.environ['MASTER_ADDR'] = 'xxx.xx.xx.xxx'  # fill in this machine's IP address
    os.environ['MASTER_PORT'] = '29555'          # a free port
    world_size = torch.cuda.device_count()
    # spawn one process per visible GPU
    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main()
    print('Done!')
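The loop above trains on random tensors. With a real dataset, each process should read a disjoint shard of the data, which is what DistributedSampler is for. Below is a minimal sketch meant to replace the random-tensor loop inside example(), reusing local_rank, world_size, ddp_model, loss_fn and optimizer from there; the TensorDataset is only a stand-in for your own data:

# hypothetical dataset used purely for illustration
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.randn(1000, 10), torch.randn(1000, 10))
# give each process its own shard of the dataset
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=local_rank, shuffle=True)
loader = DataLoader(dataset, batch_size=20, sampler=sampler)

for epoch in range(5):
    sampler.set_epoch(epoch)  # reshuffle the shards differently each epoch
    for inputs, labels in loader:
        outputs = ddp_model(inputs.cuda(local_rank))
        loss = loss_fn(outputs, labels.cuda(local_rank))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()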