Distributed training is less about making a single GPU do more work and more about making many GPUs coordinate their efforts to solve a single, massive problem.
Let’s see this in action with a simple PyTorch DistributedDataParallel example. Imagine we have two GPUs, cuda:0 and cuda:1, and we want to train a small model on some dummy data.
```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

# --- Configuration ---
# torchrun sets WORLD_SIZE and RANK for every process it launches
WORLD_SIZE = int(os.environ['WORLD_SIZE'])  # Total number of processes (2 here)
RANK = int(os.environ['RANK'])              # This process's rank (0 or 1 here)
MASTER_ADDR = 'localhost'
MASTER_PORT = '12355'
BACKEND = 'nccl'  # For NVIDIA GPUs

# --- Initialization ---
# Bind this process to its GPU before NCCL communicators are created.
# (Single-node example: rank == GPU index. Use LOCAL_RANK on multi-node jobs.)
torch.cuda.set_device(RANK)
dist.init_process_group(backend=BACKEND, rank=RANK, world_size=WORLD_SIZE,
                        init_method=f'tcp://{MASTER_ADDR}:{MASTER_PORT}')

# --- Model Definition ---
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)  # Input size 10, output size 2

    def forward(self, x):
        return self.fc(x)

# --- Data and Model Setup ---
model = SimpleModel().to(f'cuda:{RANK}')
# Wrap the model with DistributedDataParallel.
# device_ids should be a list containing this process's single GPU ID.
ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=[RANK])

# Dummy data
# Each process will have a subset of data or generate its own.
# For simplicity, we generate identical data here, but in real scenarios
# you'd use DistributedSampler to ensure unique data per process.
data = torch.randn(64, 10).to(f'cuda:{RANK}')           # Batch size 64, 10 features
labels = torch.randint(0, 2, (64,)).to(f'cuda:{RANK}')

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.01)

# --- Training Loop ---
for epoch in range(2):  # Two passes over the (single) dummy batch
    optimizer.zero_grad()
    outputs = ddp_model(data)
    loss = criterion(outputs, labels)
    loss.backward()  # DDP all-reduces (averages) gradients across processes here
    optimizer.step()
    if RANK == 0:  # Print loss only from the main process
        print(f'Epoch {epoch + 1}, Loss: {loss.item():.4f}')

# --- Cleanup ---
dist.destroy_process_group()
```
To run this, you’d typically use torchrun (or torch.distributed.launch for older PyTorch versions):
```bash
# Save the code above as train_ddp.py
# Run this command in your terminal
torchrun --nproc_per_node=2 train_ddp.py
```
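torchrun launches one process per GPU and sets environment variables such as RANK, LOCAL_RANK, and WORLD_SIZE for each of them, so the process driving cuda:0 knows it is rank 0 and the process driving cuda:1 knows it is rank 1. The init_process_group call establishes communication channels between these ranks using the specified backend (NCCL is the de facto standard for NVIDIA GPUs). DistributedDataParallel then wraps your model: when it is constructed, it broadcasts rank 0’s parameters so every replica starts from the same weights, and during the backward pass it averages gradients across all participating GPUs. Because every replica starts from the same weights and applies the same averaged gradient, the copies stay identical, and (for a mean-reduced loss with equal per-GPU batch sizes) each update matches the one a single model would compute on the combined batch from all GPUs. The optimizer.step() call then updates the weights using these averaged gradients.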
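The consistency argument can be checked with a tiny pure-Python simulation, with no GPUs involved. This is a sketch under stated assumptions: the one-parameter model, the hypothetical helper names local_gradient and ddp_step, and the plain-list "replicas" are illustrative stand-ins for what DDP does with real tensors.

```python
def local_gradient(w, xs, ys):
    """Gradient of the mean squared error of the model y = w * x over one process's batch."""
    n = len(xs)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def ddp_step(weights, shards, lr):
    """One simulated DDP step: local backward, all-reduce (average), identical update."""
    grads = [local_gradient(w, xs, ys) for w, (xs, ys) in zip(weights, shards)]
    avg = sum(grads) / len(grads)           # what the all-reduce leaves on every rank
    return [w - lr * avg for w in weights]  # every replica applies the same update


# Two simulated "GPUs", each with a different shard of data drawn from y = 3x.
shards = [([1.0, 2.0], [3.0, 6.0]), ([3.0, 4.0], [9.0, 12.0])]
weights = [0.5, 0.5]  # DDP broadcasts rank 0's weights, so replicas start equal

for _ in range(50):
    weights = ddp_step(weights, shards, lr=0.02)

print(weights)
```

Even though the two replicas see different data and compute different local gradients, they end every step with bit-identical weights, converging toward the true slope of 3.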
The core problem distributed training solves is that a single GPU, no matter how powerful, has finite memory and processing capacity. For truly massive datasets and complex models (like large language models or high-resolution image generators), a single GPU may be unable to hold the model in memory or to finish training in a reasonable timeframe. Distributed training breaks this bottleneck by parallelizing the computation across multiple devices.
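A quick back-of-the-envelope calculation shows how fast a single GPU runs out of memory. The sketch below assumes plain fp32 training with the Adam optimizer, which keeps 4 bytes each for the weights, the gradients, and Adam's two moment estimates (16 bytes per parameter); it ignores activations, temporary buffers, and framework overhead, so real usage is higher. The 7-billion-parameter model and the 80 GB GPU are illustrative figures, not measurements, and training_memory_gb is a hypothetical helper.

```python
def training_memory_gb(num_params, bytes_per_param=16):
    """Rough lower bound on training memory: weights + grads + Adam states, no activations."""
    return num_params * bytes_per_param / 1e9


params = 7_000_000_000  # illustrative 7B-parameter model
gpu_memory_gb = 80      # illustrative single-GPU capacity

needed = training_memory_gb(params)
print(f"~{needed:.0f} GB needed vs {gpu_memory_gb} GB available")
print("fits on one GPU" if needed <= gpu_memory_gb else "does not fit on one GPU")
```

Data parallelism alone does not lower this figure, because every DDP replica keeps its own full copy of the weights, gradients, and optimizer state; splitting the model itself is what reduces the per-GPU footprint.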
There are two primary strategies:
- Data Parallelism: This is what DistributedDataParallel (DDP) implements. The model is replicated on each GPU, and each GPU processes a different shard of the data. Gradients are computed locally and then averaged across all GPUs before the weights are updated. This is the natural choice when the model fits on a single GPU and you want to process more data per step or finish training sooner.
- Model Parallelism: Here, the model itself is split across multiple GPUs, with different layers or parts of the model residing on different devices. In the simplest layer-wise split, activations flow sequentially from one device to the next. This is necessary when the model is too large to fit into the memory of a single GPU. Frameworks like DeepSpeed or Megatron-LM offer sophisticated ways to manage model parallelism.
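The claim that averaging local gradients reproduces a single large-batch update can be checked numerically. This stdlib-only sketch assumes a mean-reduced loss (the default reduction for nn.CrossEntropyLoss) and equally sized shards; the linear model, the data values, and the helper name mse_grad are hypothetical illustrations.

```python
def mse_grad(w, b, xs, ys):
    """Gradients of mean((w*x + b - y)^2) with respect to w and b."""
    n = len(xs)
    gw = sum(2 * (w * x + b - y) * x for x, y in zip(xs, ys)) / n
    gb = sum(2 * (w * x + b - y) for x, y in zip(xs, ys)) / n
    return gw, gb


xs = [0.5, 1.5, -2.0, 3.0, 1.0, -0.5, 2.5, 4.0]
ys = [1.0, 2.0, -1.0, 5.0, 0.0, 0.5, 3.0, 6.0]
w, b = 0.3, -0.1
world_size = 2

# Full-batch gradient, as a single GPU would compute it.
full = mse_grad(w, b, xs, ys)

# Each rank takes an equal, disjoint shard (rank r takes every world_size-th sample).
shard_grads = [mse_grad(w, b, xs[r::world_size], ys[r::world_size]) for r in range(world_size)]
averaged = tuple(sum(g[i] for g in shard_grads) / world_size for i in range(2))

print(full, averaged)
```

The two results agree up to floating-point rounding. If the shards differ in size, or the loss uses reduction='sum', the plain average no longer matches the full-batch gradient exactly.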
In practice, complex training often involves a hybrid approach, combining data and model parallelism to tackle both massive datasets and enormous models. Frameworks like PyTorch Lightning, TensorFlow Keras, and specialized libraries like DeepSpeed and Horovod abstract away much of the complexity, allowing you to focus on the model and data.
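In a hybrid setup, each process gets a coordinate in a grid of parallelism dimensions, and the world size is the product of those dimensions. The mapping below is a minimal sketch assuming a 2D layout in which consecutive ranks hold the pieces of one model replica (a common convention, but each framework defines its own ordering); grid_coords is a hypothetical helper, not a framework API.

```python
def grid_coords(rank, model_parallel_size, world_size):
    """Map a global rank to (data_parallel_index, model_parallel_index)."""
    if world_size % model_parallel_size != 0:
        raise ValueError("world_size must be divisible by model_parallel_size")
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    # Assumed layout: consecutive ranks hold the pieces of one model replica.
    return divmod(rank, model_parallel_size)


world_size, mp_size = 8, 2  # e.g. 8 GPUs: 4 replicas, each split across 2 GPUs

for rank in range(world_size):
    replica, shard = grid_coords(rank, mp_size, world_size)
    print(f"rank {rank}: replica {replica}, model shard {shard}")

# Ranks holding the same model shard form a data-parallel group and
# all-reduce that shard's gradients together.
dp_groups = {
    shard: [r for r in range(world_size) if grid_coords(r, mp_size, world_size)[1] == shard]
    for shard in range(mp_size)
}
print(dp_groups)
```

Under this layout, model-parallel traffic stays between neighboring ranks (often GPUs on the same node), while the gradient all-reduce runs within each data-parallel group.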
The communication backend is crucial. For NVIDIA GPUs, nccl (NVIDIA Collective Communications Library) is highly optimized for GPU-to-GPU communication over NVLink and PCIe within a node, and over network interconnects such as InfiniBand between nodes; it is generally the fastest and most efficient choice there. For CPU training, gloo is the usual backend, and mpi is also available when PyTorch is built with MPI support.
When loss.backward() is called in a DDP setup, it doesn’t just compute gradients for the local model; it also triggers gradient synchronization. DDP registers autograd hooks that launch an all-reduce on buckets of gradients as soon as those gradients are ready, overlapping communication with the rest of the backward pass. After the all-reduce, every GPU holds the same gradient: the sum across all processes divided by the world size, i.e. the average. This is what keeps the model weights synchronized across all replicas.
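To make the all-reduce concrete, here is a pure-Python simulation of ring all-reduce, one of the algorithms NCCL can use. It is a simplified sketch: real implementations pipeline chunks over physical links and run the steps concurrently, whereas this loop updates plain lists in lockstep. The function name ring_all_reduce and the list-of-lists data layout are illustrative assumptions.

```python
def ring_all_reduce(buffers):
    """Simulate ring all-reduce: afterwards every rank holds the element-wise average.

    buffers: one equal-length list of floats per rank (modified in place).
    """
    n = len(buffers)
    length = len(buffers[0])
    # Split the vector into n contiguous chunks.
    bounds = [(k * length) // n for k in range(n + 1)]

    def chunk(k):
        return range(bounds[k], bounds[k + 1])

    # Phase 1, reduce-scatter: after n-1 steps rank r holds the full sum of chunk (r+1) % n.
    for step in range(n - 1):
        # Snapshot what every rank sends this step, then deliver (all sends happen "at once").
        msgs = [(r, (r - step) % n) for r in range(n)]
        payloads = [[buffers[r][i] for i in chunk(k)] for r, k in msgs]
        for (r, k), data in zip(msgs, payloads):
            dst = (r + 1) % n
            for i, v in zip(chunk(k), data):
                buffers[dst][i] += v

    # Phase 2, all-gather: pass the completed chunks around so every rank has all of them.
    for step in range(n - 1):
        msgs = [(r, (r + 1 - step) % n) for r in range(n)]
        payloads = [[buffers[r][i] for i in chunk(k)] for r, k in msgs]
        for (r, k), data in zip(msgs, payloads):
            dst = (r + 1) % n
            for i, v in zip(chunk(k), data):
                buffers[dst][i] = v

    # Dividing the sum by the number of processes gives the average DDP applies.
    for buf in buffers:
        for i in range(length):
            buf[i] /= n
    return buffers


grads = [
    [1.0, 2.0, 3.0, 4.0],     # rank 0's local gradients
    [5.0, 6.0, 7.0, 8.0],     # rank 1
    [9.0, 10.0, 11.0, 12.0],  # rank 2
]
ring_all_reduce(grads)
print(grads)
```

The appeal of the ring is that each rank only ever talks to its neighbor and sends roughly 2 × (n − 1)/n times the gradient size in total, regardless of how many ranks participate.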
The init_method parameter in dist.init_process_group is how processes find each other to establish the communication group. tcp://localhost:12355 means rank 0 listens on port 12355 of the local machine and the other ranks connect to it there to coordinate. For multi-node training, MASTER_ADDR would be the IP address (or hostname) of the node running rank 0, and MASTER_PORT a free port on that node that every other node can reach.
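When a hard-coded port such as 12355 might already be taken, one common trick is to ask the operating system for a free port and build the tcp:// init URL from it. This is a stdlib-only sketch with hypothetical helper names (find_free_port, tcp_init_method). The port is only guaranteed free at the moment it is checked, and it must be chosen once (for example by a launcher script) and shared with every process; calling the helper independently in each process would give each one a different port. That makes it best suited to single-node experiments.

```python
import socket


def find_free_port():
    """Ask the OS for an unused TCP port by binding to port 0."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


def tcp_init_method(master_addr, master_port):
    """Build a tcp:// init_method string for dist.init_process_group."""
    return f"tcp://{master_addr}:{master_port}"


port = find_free_port()
init_method = tcp_init_method("localhost", port)
print(init_method)
# A launcher would hand this same string to every process, which then calls e.g.
# dist.init_process_group(backend="nccl", rank=rank, world_size=world_size, init_method=init_method)
```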
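The device_ids argument in DistributedDataParallel(model, device_ids=[RANK]) is vital. It tells DDP which single GPU this process’s model replica lives on; the usual setup is exactly one process per GPU, which avoids processes contending for the same device. In this single-node example, the global rank doubles as the GPU index. On multiple nodes, however, the global RANK can exceed the number of GPUs on a node, so you should index devices with the per-node LOCAL_RANK that torchrun also sets.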
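The relationship between global and local ranks can be sketched as follows, assuming every node runs the same number of processes and ranks are assigned node by node. torchrun exposes LOCAL_RANK directly, so in practice you read it rather than compute it; rank_layout is a hypothetical helper for illustration only.

```python
def rank_layout(global_rank, procs_per_node):
    """Return (node_index, local_rank) for a homogeneous, node-by-node rank assignment."""
    return divmod(global_rank, procs_per_node)


# 2 nodes x 4 GPUs = 8 processes; global ranks 4-7 live on node 1 as cuda:0-cuda:3.
for rank in range(8):
    node, local = rank_layout(rank, procs_per_node=4)
    print(f"global rank {rank} -> node {node}, device cuda:{local}")
```

Passing device_ids=[rank] for global rank 5 would ask for cuda:5, which does not exist on a 4-GPU node; device_ids=[local_rank] points at cuda:1 instead.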
A common pitfall is incorrect environment variable setup or communication backend issues. If RANK isn’t set correctly for each process, or if MASTER_ADDR/MASTER_PORT are wrong, processes won’t be able to find each other, leading to hangs or errors during init_process_group. Always ensure WORLD_SIZE matches the number of processes you’re actually launching: if it is too large, init_process_group keeps waiting for ranks that never arrive until it times out.
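A cheap guard against these mistakes is to validate the launch environment before calling init_process_group, so a misconfiguration fails fast with a clear message instead of hanging. The sketch below is a hypothetical stdlib-only helper (check_dist_env); the variable names are the ones torchrun sets, and the checks are illustrative rather than exhaustive.

```python
import os


def check_dist_env(env=None):
    """Return a list of problems with the distributed launch environment (empty if OK)."""
    env = os.environ if env is None else env
    problems = []
    for key in ("RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"):
        if key not in env:
            problems.append(f"{key} is not set")
    try:
        rank = int(env.get("RANK", "0"))
        world_size = int(env.get("WORLD_SIZE", "1"))
        if world_size < 1:
            problems.append("WORLD_SIZE must be at least 1")
        elif not 0 <= rank < world_size:
            problems.append(f"RANK {rank} is outside [0, {world_size})")
    except ValueError:
        problems.append("RANK and WORLD_SIZE must be integers")
    port = env.get("MASTER_PORT")
    if port is not None and not (port.isdigit() and 0 < int(port) < 65536):
        problems.append(f"MASTER_PORT {port!r} is not a valid TCP port")
    return problems


good = {"RANK": "1", "WORLD_SIZE": "2", "MASTER_ADDR": "localhost", "MASTER_PORT": "12355"}
bad = {"RANK": "2", "WORLD_SIZE": "2", "MASTER_ADDR": "localhost"}
print(check_dist_env(good))
print(check_dist_env(bad))
```

In a training script you might call check_dist_env() at startup and raise if the returned list is non-empty.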
The next challenge you’ll face is efficiently loading and distributing large datasets without bottlenecks, often solved by using torch.utils.data.distributed.DistributedSampler and optimized data loading pipelines.
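To see what DistributedSampler does conceptually, here is a stdlib-only sketch that mimics its default behavior: shuffle the indices with a seed that depends on the epoch, pad the list so it divides evenly across ranks, then give each rank every world_size-th index. It is an illustration, not PyTorch's implementation (which uses its own random generator, so the actual orders differ), and shard_indices is a hypothetical name. With the real sampler, you call sampler.set_epoch(epoch) at the start of each epoch so the shuffle changes between epochs.

```python
import math
import random


def shard_indices(dataset_len, rank, world_size, epoch=0, seed=0, shuffle=True):
    """Return the dataset indices assigned to one rank for one epoch."""
    indices = list(range(dataset_len))
    if shuffle:
        # Same seed + epoch on every rank => every rank computes the same permutation.
        random.Random(seed + epoch).shuffle(indices)
    # Pad by cycling through the list so every rank gets the same number of samples.
    total = math.ceil(dataset_len / world_size) * world_size
    indices += [indices[i % dataset_len] for i in range(total - dataset_len)]
    # Strided split: rank r takes positions r, r + world_size, r + 2 * world_size, ...
    return indices[rank:total:world_size]


shards = [shard_indices(10, r, world_size=4, epoch=0) for r in range(4)]
print(shards)
```

Because every rank derives the same permutation from the shared seed, the shards are disjoint apart from the few padded duplicates, and no index list ever has to be sent between processes.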