PyTorch’s Distributed Data Parallel (DDP) is designed to scale training across multiple GPUs, but it’s not just about throwing more hardware at the problem; it’s about efficiently coordinating gradient synchronization and data loading.
Let’s see DDP in action. Imagine we have two GPUs and a simple model.
```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)


def setup(rank, world_size):
    # Every process must agree on where the rendezvous (rank 0) lives.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    # Bind this process to its own GPU before creating the NCCL process group.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


def train_process(rank, world_size):
    setup(rank, world_size)
    model = SimpleModel().to(rank)
    # DDP broadcasts rank 0's parameters so every replica starts identical.
    ddp_model = DDP(model, device_ids=[rank])
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)

    # Dummy data: each process draws its own random batch, standing in
    # for the distinct shard a DistributedSampler would give it.
    data = torch.randn(16, 10).to(rank)
    target = torch.randn(16, 2).to(rank)

    for epoch in range(5):
        optimizer.zero_grad()
        outputs = ddp_model(data)
        loss = criterion(outputs, target)
        loss.backward()  # DDP all-reduces (averages) gradients during this call
        optimizer.step()
        if rank == 0:
            print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

    cleanup()


if __name__ == "__main__":
    world_size = 2
    # mp.spawn starts world_size processes and passes each its rank as the first argument.
    mp.spawn(train_process, args=(world_size,), nprocs=world_size, join=True)
```
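To run this, save it as ddp_example.py and execute python ddp_example.py. The script starts its own two worker processes with mp.spawn, so do not also launch it through torchrun or the deprecated python -m torch.distributed.launch; that would start duplicate workers competing for the same ranks and port. You'll see output from rank 0, and its loss should generally decrease over the five epochs.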
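DDP is data parallelism: every GPU holds a full replica of the model and processes a different slice of the data concurrently, which speeds up training. It does not help when a model is too large to fit on a single GPU; that calls for sharding or model-parallel approaches such as FullyShardedDataParallel. You get this behavior by wrapping your model in DistributedDataParallel, which handles gradient synchronization automatically. When the wrapper is constructed, it broadcasts rank 0's parameters so every replica starts identical. During loss.backward(), each process computes gradients on its own data, and DDP averages them across all processes, overlapping that communication with the rest of the backward pass. By the time optimizer.step() runs, every replica holds the same averaged gradients and applies the same update, so the replicas stay in sync and the processes collectively train a single model.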
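To make the "effectively training a single model" claim concrete, here is a minimal single-process simulation, assuming plain SGD and a mean-reduced loss. No torch.distributed is involved: the explicit averaging loop stands in for DDP's all-reduce, and the two model copies stand in for two ranks. It checks that the replicas end up bitwise identical and that their update matches one model taking a step on the whole batch (this equivalence holds when the loss is a mean and the shards are the same size).

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)

# One global batch of 16 samples, split evenly as two DDP ranks would see it.
data = torch.randn(16, 10)
target = torch.randn(16, 2)
shards = list(zip(data.chunk(2), target.chunk(2)))

criterion = nn.MSELoss()  # mean reduction, so shard losses average to the full-batch loss
base = nn.Linear(10, 2)

# Identical starting replicas, mirroring DDP's initial broadcast from rank 0.
replicas = [copy.deepcopy(base) for _ in shards]
optimizers = [torch.optim.SGD(r.parameters(), lr=0.01) for r in replicas]

# Each simulated rank runs backward on its own shard only.
for replica, (x, y) in zip(replicas, shards):
    criterion(replica(x), y).backward()

# Stand-in for DDP's all-reduce: replace every gradient with the mean across replicas.
with torch.no_grad():
    for params in zip(*(r.parameters() for r in replicas)):
        mean_grad = torch.stack([p.grad for p in params]).mean(dim=0)
        for p in params:
            p.grad.copy_(mean_grad)

for opt in optimizers:
    opt.step()

# Reference: one model taking a single SGD step on the whole batch.
single = copy.deepcopy(base)
single_opt = torch.optim.SGD(single.parameters(), lr=0.01)
criterion(single(data), target).backward()
single_opt.step()

for a, b, ref in zip(replicas[0].parameters(), replicas[1].parameters(), single.parameters()):
    # Replicas agree exactly; the full-batch model agrees up to float rounding.
    print(torch.equal(a, b), torch.allclose(a, ref, atol=1e-6))
```

In real DDP the averaging happens inside loss.backward() and across processes, but the arithmetic is the same.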
The core components you control are world_size, rank, init_process_group, and the DDP wrapper itself. world_size is the total number of processes, typically one per GPU. rank is the unique identifier for each process, from 0 to world_size - 1. init_process_group sets up the communication backend (such as nccl for NVIDIA GPUs) and waits until all world_size processes have connected to each other. The DistributedDataParallel wrapper performs the gradient averaging. For data loading, you'll typically use torch.utils.data.distributed.DistributedSampler so that each process gets a distinct, non-overlapping subset of your dataset.
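A minimal sketch of DistributedSampler, assuming a toy ten-sample dataset and a world_size of 2. Passing num_replicas and rank explicitly lets it run without a process group; inside a DDP worker you can omit them and the sampler reads them from the initialized group. The helper name make_loader is hypothetical. Note that the sampler replaces shuffle=True on the DataLoader, and that sampler.set_epoch(epoch) should be called at the start of every epoch, otherwise each epoch repeats the same order. If the dataset size is not divisible by world_size, the sampler pads by repeating a few samples (or drops them with drop_last=True).

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Ten samples whose values equal their indices, so the split is easy to read.
dataset = TensorDataset(torch.arange(10, dtype=torch.float32).unsqueeze(1))
WORLD_SIZE = 2


def make_loader(rank, epoch):
    # Explicit num_replicas/rank so no process group is needed for this demo.
    sampler = DistributedSampler(
        dataset, num_replicas=WORLD_SIZE, rank=rank, shuffle=True, seed=0
    )
    # Reseeds the shuffle for this epoch; every rank must pass the same epoch.
    sampler.set_epoch(epoch)
    # The sampler does the shuffling, so the DataLoader gets no shuffle flag.
    loader = DataLoader(dataset, batch_size=2, sampler=sampler)
    return sampler, loader


if __name__ == "__main__":
    for rank in range(WORLD_SIZE):
        sampler, loader = make_loader(rank, epoch=0)
        print(f"rank {rank} indices: {list(sampler)}")
        for (x,) in loader:
            pass  # a real worker would run forward/backward on x here
```

Running it prints two index lists that never overlap and together cover all ten samples.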
The backend argument in torch.distributed.init_process_group(backend="nccl") matters. nccl uses NCCL (the NVIDIA Collective Communications Library), whose collective operations are optimized for GPU-to-GPU communication, which makes it the recommended and usually fastest choice for multi-GPU training on NVIDIA hardware. Other backends exist: gloo runs collectives on CPU tensors and is the usual choice for CPU-only or multi-node CPU training, and mpi targets MPI environments but is only available when PyTorch is built from source with MPI support. The backend you choose can significantly affect performance, particularly the latency and bandwidth of inter-GPU communication.
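A small sketch of choosing a backend at runtime, assuming you want nccl when CUDA GPUs are present and gloo otherwise. The helper names pick_backend and backend_report are hypothetical; the availability checks are real torch.distributed functions that report what your PyTorch build supports. If you fall back to gloo with a CPU model, construct DistributedDataParallel without device_ids.

```python
import torch
import torch.distributed as dist


def pick_backend():
    """Prefer NCCL for CUDA GPUs; otherwise fall back to Gloo, which runs on CPU."""
    if torch.cuda.is_available() and dist.is_nccl_available():
        return "nccl"
    return "gloo"


def backend_report():
    # Which collective backends this PyTorch build was compiled with.
    return {
        "nccl": dist.is_nccl_available(),
        "gloo": dist.is_gloo_available(),
        "mpi": dist.is_mpi_available(),
    }


if __name__ == "__main__":
    print(backend_report())
    print("chosen backend:", pick_backend())
```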
The most surprising thing about DDP's gradient synchronization is that it needs no central server to aggregate gradients. Instead, it uses all-reduce collectives: DDP groups gradients into buckets, and as each bucket becomes ready during the backward pass, the backend (nccl or whichever you chose) exchanges pieces of those gradients among the GPUs, for example around a ring in which each GPU talks only to its neighbors. Every GPU ends up with the full averaged gradient, and no single process becomes a communication bottleneck.
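To show why no central server is needed, here is a pure-Python simulation of one classic all-reduce algorithm, the ring all-reduce, with plain lists standing in for GPU buffers. It is an illustration only; NCCL selects its own algorithms (for example ring or tree) based on topology and message size. The buffer is cut into N chunks. In the reduce-scatter phase, each worker repeatedly passes one chunk to its right-hand neighbor, which adds it to its own copy; after N - 1 steps each worker owns one fully summed chunk. In the all-gather phase, those finished chunks travel around the ring for another N - 1 steps. Each worker sends about 2(N - 1)/N of the buffer in total, which stays nearly constant as N grows.

```python
def ring_allreduce(buffers):
    """Average equal-length vectors across simulated workers using a ring all-reduce."""
    n = len(buffers)
    length = len(buffers[0])
    # Split the index range into n contiguous chunks.
    bounds = [(k * length // n, (k + 1) * length // n) for k in range(n)]
    bufs = [list(b) for b in buffers]

    # Phase 1, reduce-scatter: worker i sends chunk (i - step) to worker i + 1,
    # which adds it. Messages are collected first so all sends are simultaneous.
    for step in range(n - 1):
        messages = []
        for i in range(n):
            c = (i - step) % n
            lo, hi = bounds[c]
            messages.append(((i + 1) % n, c, bufs[i][lo:hi]))
        for dst, c, payload in messages:
            lo = bounds[c][0]
            for j, v in enumerate(payload):
                bufs[dst][lo + j] += v

    # Worker i now holds the full sum of chunk (i + 1).
    # Phase 2, all-gather: pass finished chunks around the ring, overwriting.
    for step in range(n - 1):
        messages = []
        for i in range(n):
            c = (i + 1 - step) % n
            lo, hi = bounds[c]
            messages.append(((i + 1) % n, c, bufs[i][lo:hi]))
        for dst, c, payload in messages:
            lo, hi = bounds[c]
            bufs[dst][lo:hi] = payload

    # Divide the sums by the worker count to get the average, as DDP does.
    return [[v / n for v in b] for b in bufs]


if __name__ == "__main__":
    print(ring_allreduce([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]))
    # → [[3.0, 4.0, 5.0, 6.0], [3.0, 4.0, 5.0, 6.0]]
```

Because each fully reduced chunk is computed on exactly one worker and then copied, every worker ends with bitwise-identical results.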
Once you have DDP working seamlessly on a single machine with multiple GPUs, the next logical step is scaling to multiple machines, which introduces network latency and requires careful management of the MASTER_ADDR and MASTER_PORT across all nodes.
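One common way to handle that bookkeeping is to let torchrun start the workers instead of mp.spawn. The sketch below assumes torchrun as the launcher, which exports RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT to every worker process. DistConfig, dist_config_from_env, and init_from_env are hypothetical helper names, not PyTorch APIs.

```python
import os
from dataclasses import dataclass
from typing import Mapping

import torch
import torch.distributed as dist


@dataclass
class DistConfig:
    rank: int         # global rank across all nodes
    local_rank: int   # GPU index on this node
    world_size: int   # total processes on all nodes combined
    master_addr: str  # rendezvous host, the same on every node
    master_port: int


def dist_config_from_env(environ: Mapping[str, str] = os.environ) -> DistConfig:
    """Collect the variables torchrun exports to each worker process."""
    return DistConfig(
        rank=int(environ["RANK"]),
        local_rank=int(environ["LOCAL_RANK"]),
        world_size=int(environ["WORLD_SIZE"]),
        master_addr=environ["MASTER_ADDR"],
        master_port=int(environ["MASTER_PORT"]),
    )


def init_from_env() -> DistConfig:
    cfg = dist_config_from_env()
    # Use the node-local index, not the global rank, to pick this process's GPU.
    torch.cuda.set_device(cfg.local_rank)
    # With no init_method, init_process_group uses "env://" and reads
    # MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE from the environment.
    dist.init_process_group("nccl")
    return cfg
```

For two nodes with two GPUs each, you would run torchrun --nnodes=2 --nproc_per_node=2 --node_rank=0 --master_addr=10.0.0.1 --master_port=12355 train.py on the first node and the same command with --node_rank=1 on the second, where 10.0.0.1 is a placeholder for the first node's reachable address and train.py is a script that calls init_from_env().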