Distributed training jobs can resume after failures by periodically saving their state to a shared location, allowing a new training run to pick up exactly where the previous one left off.

Let’s see this in action. Imagine you’re training a massive language model across 16 GPUs. A single GPU fails, or the network connection drops. Without checkpoints, you’d lose hours, maybe days, of training. With them, you can restart from the last saved point.

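Here’s a simplified PyTorch example using FSDP (torch.distributed.fsdp, Fully Sharded Data Parallel), a common way to train models too large to fit on a single GPU. It assumes one process per GPU (for example, launched with torchrun) and an already-initialized process group: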

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)
from torch.optim import Adam

# Assumes dist.init_process_group(backend="nccl") has already been called
# and torch.cuda.set_device(local_rank) has been set for this process.
rank = dist.get_rank()

# Define your model
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1024, 1024)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(1024, 10)

    def forward(self, x):
        return self.layer2(self.relu(self.layer1(x)))

model = MyModel()

# Wrap the model with FSDP. This is a simplified example; a real setup would
# usually also pass an auto_wrap_policy, mixed-precision settings, and so on.
fsdp_model = FSDP(model, device_id=torch.cuda.current_device())

# Create the optimizer *after* wrapping so it sees FSDP's sharded parameters
optimizer = Adam(fsdp_model.parameters(), lr=1e-3)

# Checkpoint directory on storage that every node can reach
checkpoint_dir = "/mnt/shared_storage/my_training_checkpoints"
# Small text file holding the path of the newest complete checkpoint
latest_pointer = os.path.join(checkpoint_dir, "latest.txt")
if rank == 0:
    os.makedirs(checkpoint_dir, exist_ok=True)

# --- Resuming Logic ---
start_epoch = 0
resume_from_checkpoint = True  # Set to False to start from scratch

# Rank 0 decides which checkpoint (if any) to resume from, then tells every rank
ckpt_path = [None]
if resume_from_checkpoint and rank == 0 and os.path.exists(latest_pointer):
    with open(latest_pointer) as f:
        candidate = f.read().strip()
    if os.path.exists(candidate):
        ckpt_path[0] = candidate
dist.broadcast_object_list(ckpt_path, src=0)

if ckpt_path[0] is not None:
    # Every rank must load: FSDP re-shards the full state dict across ranks,
    # so load_state_dict is a collective call. If loading fails, the job
    # crashes loudly instead of silently restarting from scratch.
    checkpoint = torch.load(ckpt_path[0], map_location="cpu")
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
        fsdp_model.load_state_dict(checkpoint["model_state_dict"])
        # Convert the full (unsharded) optimizer state into this rank's shard
        optim_state = FSDP.optim_state_dict_to_load(
            model=fsdp_model,
            optim=optimizer,
            optim_state_dict=checkpoint["optimizer_state_dict"],
        )
        optimizer.load_state_dict(optim_state)
    start_epoch = checkpoint["epoch"]  # identical on all ranks, read from the same file
    if rank == 0:
        print(f"Resumed from {ckpt_path[0]}. Starting from epoch {start_epoch}")
elif rank == 0:
    print("No checkpoint found. Starting from scratch.")

# --- Training Loop ---
num_epochs = 10
batch_size = 32
input_features = 1024
loss_fn = nn.CrossEntropyLoss()

for epoch in range(start_epoch, num_epochs):
    # Simulated data; a real job would use a DataLoader with a DistributedSampler
    inputs = torch.randn(batch_size, input_features).cuda()
    labels = torch.randint(0, 10, (batch_size,)).cuda()

    # Forward pass
    outputs = fsdp_model(inputs)
    loss = loss_fn(outputs, labels)

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if rank == 0:
        print(f"Epoch {epoch}, Loss: {loss.item()}")

    # --- Checkpointing Logic ---
    if epoch % 2 == 0:  # Save every 2 epochs
        # Gathering the full state is collective: ALL ranks must make these
        # calls, even though only rank 0 receives (and writes) the result.
        with FSDP.state_dict_type(
            fsdp_model,
            StateDictType.FULL_STATE_DICT,
            state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
        ):
            model_state = fsdp_model.state_dict()
            optim_state = FSDP.optim_state_dict(fsdp_model, optimizer)

        if rank == 0:
            ckpt_file = os.path.join(checkpoint_dir, f"epoch_{epoch}.pt")
            print(f"Saving checkpoint for epoch {epoch}...")
            torch.save({
                "epoch": epoch + 1,  # Save the next epoch to start from
                "model_state_dict": model_state,
                "optimizer_state_dict": optim_state,
            }, ckpt_file)
            # Update the pointer only after the checkpoint is fully written,
            # and atomically (write a temp file, then rename over the old one)
            tmp_pointer = latest_pointer + ".tmp"
            with open(tmp_pointer, "w") as f:
                f.write(ckpt_file)
            os.replace(tmp_pointer, latest_pointer)
            print("Checkpoint saved.")
        dist.barrier()  # keep ranks in step until the save is finished

# After training, save a final consolidated copy of the weights.
# state_dict() is collective here too, so every rank enters the context.
with FSDP.state_dict_type(
    fsdp_model,
    StateDictType.FULL_STATE_DICT,
    state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
):
    final_state = fsdp_model.state_dict()
if rank == 0:
    torch.save(final_state, os.path.join(checkpoint_dir, "final_model.pt"))
```

The core idea is to capture the entire state of your distributed computation at a given point. This includes:

  1. Model weights: The parameters of your neural network.
  2. Optimizer state: Crucially, optimizers like Adam or SGD with momentum maintain internal state (e.g., momentum buffers, variance estimates) that is essential for resuming training correctly. If you only save model weights, the optimizer state is reset on resume, which can lead to a noticeably different training trajectory.
  3. Epoch/Iteration number: To know where to resume from in the training loop.
  4. Random number generator states: If your training involves stochasticity (e.g., data shuffling, data augmentation, dropout) that you want to be reproducible across restarts, you may need to save and restore the RNG states for PyTorch (CPU and CUDA), NumPy, and Python’s random module. Each rank has its own RNG streams, so these states are per-rank.
  5. Distributed state: Complex distributed setups may have communication or sharding state that must be preserved. FSDP handles much of this internally through its state-dict APIs, which is also why every rank must take part in those calls.
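The RNG item in the list above is easy to get wrong, so here is a minimal sketch of capturing and restoring those states. The helper names capture_rng_state and restore_rng_state are hypothetical, not part of PyTorch. In a distributed job each rank would presumably save its own copy (for example, to a per-rank file), since a checkpoint written only by rank 0 cannot hold the other ranks' streams.

```python
import random

import numpy as np
import torch

def capture_rng_state() -> dict:
    """Collect RNG states so a resumed run can continue the same random streams."""
    state = {
        "python": random.getstate(),
        "numpy": np.random.get_state(),
        "torch_cpu": torch.get_rng_state(),
    }
    if torch.cuda.is_available():
        # One generator state per visible GPU
        state["torch_cuda"] = torch.cuda.get_rng_state_all()
    return state

def restore_rng_state(state: dict) -> None:
    """Put every generator back exactly where capture_rng_state() found it."""
    random.setstate(state["python"])
    np.random.set_state(state["numpy"])
    torch.set_rng_state(state["torch_cpu"])
    if "torch_cuda" in state and torch.cuda.is_available():
        torch.cuda.set_rng_state_all(state["torch_cuda"])
```

Because this dict holds NumPy and Python objects rather than only tensors, recent PyTorch versions (where torch.load defaults to weights_only=True) may refuse to load it unless you pass weights_only=False for a file you trust.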

The "shared location" is critical. This must be accessible by all ranks that might participate in the resumed job. This typically means a network file system (NFS), cloud object storage (S3, GCS), or a distributed file system (Ceph, GlusterFS). Local disk on a single node won’t work if that node is the one that failed.
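One practical wrinkle with shared storage is that a job can die halfway through writing a checkpoint, leaving a truncated file that a later resume tries to load. A common remedy is to write to a temporary file in the same directory and then rename it into place. This sketch assumes a POSIX-style filesystem where a rename within one directory is atomic (true for local filesystems and generally for NFS, but worth confirming for your storage); atomic_write_bytes is a hypothetical helper name. Object stores such as S3 or GCS have no rename, so there you would rely on their own upload semantics instead.

```python
import os
import tempfile

def atomic_write_bytes(path: str, data: bytes) -> None:
    """Write data so readers see either the old file or the complete new one."""
    directory = os.path.dirname(os.path.abspath(path))
    # The temp file must be on the same filesystem for os.replace to be atomic
    fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".tmp_ckpt_")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
            f.flush()
            os.fsync(f.fileno())  # push bytes to storage before the rename
        os.replace(tmp_path, path)
    except BaseException:
        # Remove the partial temp file, then re-raise the original error
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
```

To use it with PyTorch, serialize into memory first (torch.save(state, buf) with buf = io.BytesIO()), then call atomic_write_bytes(path, buf.getvalue()). That holds the whole checkpoint in host memory, which may be a problem for very large models.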

A useful mental model: treat your distributed job as one logical process that happens to span many machines. A failure effectively kills that process, and restarting means launching a new set of processes. Those new processes should be initialized not from scratch but from the most recent snapshot of the old process’s state.

What many people miss is that saving only the model’s state_dict() is not enough. The optimizer’s internal state (for Adam, the exponentially decaying averages of past gradients and squared gradients; for SGD with momentum, the momentum buffers) is just as vital for a seamless resume. Without it, the optimizer behaves as if training had just started, discarding its accumulated momentum and gradient-variance estimates, which can significantly alter convergence. FSDP adds a further trap: a rank’s plain optimizer.state_dict() only covers that rank’s shard of flattened parameters, so the full optimizer state should be gathered with FSDP.optim_state_dict.
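The claim about optimizer state is easy to check on a toy problem. This sketch trains a tiny linear model with Adam three ways: uninterrupted, resumed from a mid-run snapshot of both model and optimizer state, and resumed from the model weights alone. The model, data, learning rate, and step counts are arbitrary illustrative choices, and an in-memory deepcopy stands in for torch.save/torch.load.

```python
import copy

import torch
import torch.nn as nn

def make_run(seed: int = 0):
    """Fresh model + Adam optimizer, as a newly launched process would create."""
    torch.manual_seed(seed)
    model = nn.Linear(4, 1)
    opt = torch.optim.Adam(model.parameters(), lr=0.1)
    return model, opt

def train(model, opt, batches):
    for x, y in batches:
        opt.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        opt.step()

# Fixed synthetic data so every run sees the same batches
g = torch.Generator().manual_seed(123)
batches = [(torch.randn(8, 4, generator=g), torch.randn(8, 1, generator=g)) for _ in range(10)]

# 1) Uninterrupted reference run: 10 steps
ref_model, ref_opt = make_run()
train(ref_model, ref_opt, batches)

# 2) Train 5 steps, snapshot, then "crash"
model, opt = make_run()
train(model, opt, batches[:5])
snapshot = {"model": copy.deepcopy(model.state_dict()),
            "optim": copy.deepcopy(opt.state_dict())}

# 3) Resume with model AND optimizer state (new process, different init)
full_model, full_opt = make_run(seed=99)
full_model.load_state_dict(snapshot["model"])
full_opt.load_state_dict(snapshot["optim"])
train(full_model, full_opt, batches[5:])

# 4) Resume with model weights only: Adam's moments and step count start over
weights_model, weights_opt = make_run(seed=99)
weights_model.load_state_dict(snapshot["model"])
train(weights_model, weights_opt, batches[5:])
```

On CPU with identical data, the full resume should track the uninterrupted run, while the weights-only resume drifts away because Adam restarts its step count and moment estimates from zero.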

The next problem you’ll hit is managing checkpoint storage and cleanup as your training runs for weeks or months.
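As a starting point for that problem, a simple retention policy keeps only the newest few checkpoints. This sketch assumes the epoch_<N>.pt naming from the example above and that only rank 0 calls it; prune_checkpoints and keep_last are hypothetical names. Because it orders files by epoch number, the newest checkpoint (the one the latest pointer refers to) is always kept.

```python
import os
import re

_CKPT_RE = re.compile(r"^epoch_(\d+)\.pt$")

def prune_checkpoints(checkpoint_dir: str, keep_last: int = 3) -> list[str]:
    """Delete all but the newest `keep_last` epoch_<N>.pt files; return removed names."""
    if keep_last < 1:
        raise ValueError("keep_last must be at least 1 so a resumable checkpoint remains")
    found = []
    for name in os.listdir(checkpoint_dir):
        m = _CKPT_RE.match(name)
        if m:
            found.append((int(m.group(1)), name))
    found.sort()  # numeric order, so epoch_10 sorts after epoch_8
    removed = []
    for _, name in found[:-keep_last]:
        os.remove(os.path.join(checkpoint_dir, name))
        removed.append(name)
    return removed
```

Files that don't match the pattern (such as the pointer file) are left alone.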
