分布式数据并行入门
原文: https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
作者:申力
DistributedDataParallel (DDP)在模块级别实现数据并行性。 它使用 Torch.distributed 程序包中的通信集合来同步梯度,参数和缓冲区。 并行性在流程内和跨流程均可用。 在一个过程中,DDP 将输入模块复制到device_ids
中指定的设备,将输入沿批次维度分散,然后将输出收集到output_device
,这与 DataParallel 相似。 在整个过程中,DDP 在正向传递中插入必要的参数同步,在反向传递中插入梯度同步。 用户可以将进程映射到可用资源,只要进程不共享 GPU 设备即可。 推荐的方法(通常是最快的方法)是为每个模块副本创建一个过程,即在一个过程中不进行任何模块复制。 本教程中的代码在 8-GPU 服务器上运行,但可以轻松地推广到其他环境。
DataParallel
和DistributedDataParallel
之间的比较
在深入探讨之前,让我们澄清一下为什么尽管增加了复杂性,但还是考虑使用DistributedDataParallel
而不是DataParallel
:
- 首先,请回顾先前的教程,如果模型太大而无法容纳在单个 GPU 上,则必须使用模型并行将其拆分到多个 GPU 中。
DistributedDataParallel
与模型并行一起使用;DataParallel
目前没有。 DataParallel
是单进程,多线程,并且只能在单台机器上运行,而DistributedDataParallel
是多进程,并且适用于单机和多机训练。 因此,即使在单机训练中,数据足够小以适合单机,DistributedDataParallel
仍比DataParallel
快。DistributedDataParallel
还预先复制模型,而不是在每次迭代时复制模型,并避免了全局解释器锁定。- 如果您的两个数据都太大而无法容纳在一台计算机和上,而您的模型又太大了以至于无法安装在单个 GPU 上,则可以将模型并行(跨多个 GPU 拆分单个模型)与
DistributedDataParallel
结合使用。 在这种情况下,每个DistributedDataParallel
进程都可以并行使用模型,而所有进程都将并行使用数据。
基本用例
要创建 DDP 模块,请首先正确设置过程组。 更多细节可以在用 PyTorch 编写分布式应用程序中找到。
import os
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# initialize the process group
dist.init_process_group("gloo", rank=rank, world_size=world_size)
# Explicitly setting seed to make sure that models created in two processes
# start from same random weights and biases.
torch.manual_seed(42)
def cleanup():
dist.destroy_process_group()
现在,让我们创建一个玩具模块,将其与 DDP 封装在一起,并提供一些虚拟输入数据。 请注意,由于DDP将0级进程中的模型状态广播到DDP构造函数中的所有其他进程,因此无需担心不同的DDP进程从不同的模型参数初始值开始。
class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(10, 10)
self.relu = nn.ReLU()
self.net2 = nn.Linear(10, 5)
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
def demo_basic(rank, world_size):
setup(rank, world_size)
# setup devices for this process, rank 1 uses GPUs [0, 1, 2, 3] and
# rank 2 uses GPUs [4, 5, 6, 7].
n = torch.cuda.device_count() // world_size
device_ids = list(range(rank * n, (rank + 1) * n))
# create model and move it to device_ids[0]
model = ToyModel().to(device_ids[0])
# output_device defaults to device_ids[0]
ddp_model = DDP(model, device_ids=device_ids)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
optimizer.zero_grad()
outputs = ddp_model(torch.randn(20, 10))
labels = torch.randn(20, 5).to(device_ids[0])
loss_fn(outputs, labels).backward()
optimizer.step()
cleanup()
def run_demo(demo_fn, world_size):
mp.spawn(demo_fn,
args=(world_size,),
nprocs=world_size,
join=True)
如您所见,DDP 包装了较低级别的分布式通信详细信息,并提供了干净的 API,就好像它是本地模型一样。 对于基本用例,DDP 仅需要几个 LoC 来设置流程组。 在将 DDP 应用到更高级的用例时,需要注意一些警告。
偏斜的处理速度
在 DDP 中,构造函数,转发方法和输出的微分是分布式同步点。 期望不同的过程以相同的顺序到达同步点,并在大致相同的时间进入每个同步点。 否则,快速流程可能会提早到达,并在等待流浪者时超时。 因此,用户负责平衡流程之间的工作负载分配。 有时,由于例如网络延迟,资源争用,不可预测的工作量峰值,不可避免地会出现偏斜的处理速度。 为了避免在这些情况下超时,请在调用 init_process_group 时传递足够大的timeout
值。
保存和加载检查点
在训练过程中通常使用torch.save
和torch.load
来检查点模块并从检查点中恢复。 有关更多详细信息,请参见保存和加载模型。 使用 DDP 时,一种优化方法是仅在一个进程中保存模型,然后将其加载到所有进程中,从而减少写开销。 这是正确的,因为所有过程都从相同的参数开始,并且梯度在向后传递中同步,因此优化程序应将参数设置为相同的值。 如果使用此优化,请确保在保存完成之前不要启动所有进程。 此外,在加载模块时,您需要提供适当的map_location
参数,以防止进程进入其他设备。 如果缺少map_location
,则torch.load
将首先将该模块加载到 CPU,然后将每个参数复制到其保存位置,这将导致同一台机器上的所有进程使用同一组设备。
def demo_checkpoint(rank, world_size):
setup(rank, world_size)
# setup devices for this process, rank 1 uses GPUs [0, 1, 2, 3] and
# rank 2 uses GPUs [4, 5, 6, 7].
n = torch.cuda.device_count() // world_size
device_ids = list(range(rank * n, (rank + 1) * n))
model = ToyModel().to(device_ids[0])
# output_device defaults to device_ids[0]
ddp_model = DDP(model, device_ids=device_ids)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
if rank == 0:
# All processes should see same parameters as they all start from same
# random parameters and gradients are synchronized in backward passes.
# Therefore, saving it in one process is sufficient.
torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)
# Use a barrier() to make sure that process 1 loads the model after process
# 0 saves it.
dist.barrier()
# configure map_location properly
rank0_devices = [x - rank * len(device_ids) for x in device_ids]
device_pairs = zip(rank0_devices, device_ids)
map_location = {'cuda:%d' % x: 'cuda:%d' % y for x, y in device_pairs}
ddp_model.load_state_dict(
torch.load(CHECKPOINT_PATH, map_location=map_location))
optimizer.zero_grad()
outputs = ddp_model(torch.randn(20, 10))
labels = torch.randn(20, 5).to(device_ids[0])
loss_fn = nn.MSELoss()
loss_fn(outputs, labels).backward()
optimizer.step()
# Use a barrier() to make sure that all processes have finished reading the
# checkpoint
dist.barrier()
if rank == 0:
os.remove(CHECKPOINT_PATH)
cleanup()
将 DDP 与模型并行性结合
DDP 还可以与多 GPU 模型一起使用,但是不支持进程内的复制。 您需要为每个模块副本创建一个进程,与每个进程的多个副本相比,通常可以提高性能。 当训练具有大量数据的大型模型时,DDP 包装多 GPU 模型特别有用。 使用此功能时,需要小心地实现多 GPU 模型,以避免使用硬编码的设备,因为会将不同的模型副本放置到不同的设备上。
class ToyMpModel(nn.Module):
def __init__(self, dev0, dev1):
super(ToyMpModel, self).__init__()
self.dev0 = dev0
self.dev1 = dev1
self.net1 = torch.nn.Linear(10, 10).to(dev0)
self.relu = torch.nn.ReLU()
self.net2 = torch.nn.Linear(10, 5).to(dev1)
def forward(self, x):
x = x.to(self.dev0)
x = self.relu(self.net1(x))
x = x.to(self.dev1)
return self.net2(x)
将多 GPU 模型传递给 DDP 时,不得设置device_ids
和output_device
。 输入和输出数据将通过应用程序或模型forward()
方法放置在适当的设备中。
def demo_model_parallel(rank, world_size):
setup(rank, world_size)
# setup mp_model and devices for this process
dev0 = rank * 2
dev1 = rank * 2 + 1
mp_model = ToyMpModel(dev0, dev1)
ddp_mp_model = DDP(mp_model)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_mp_model.parameters(), lr=0.001)
optimizer.zero_grad()
# outputs will be on dev1
outputs = ddp_mp_model(torch.randn(20, 10))
labels = torch.randn(20, 5).to(dev1)
loss_fn(outputs, labels).backward()
optimizer.step()
cleanup()
if __name__ == "__main__":
run_demo(demo_basic, 2)
run_demo(demo_checkpoint, 2)
if torch.cuda.device_count() >= 8:
run_demo(demo_model_parallel, 4)