Activation checkpointing (also called gradient checkpointing) trades compute for memory, letting you train larger models, or use larger batch sizes, on GPUs with limited memory.

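```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class MyModel(nn.Module):
    def __init__(self, use_checkpoint=False):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.layer1 = nn.Linear(1024, 1024)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(1024, 1024)
        self.layer3 = nn.Linear(1024, 10)

    def _body(self, x):
        # layer1 -> ReLU -> layer2 -> ReLU. The first ReLU output is an internal
        # activation that checkpointing avoids storing.
        return self.relu(self.layer2(self.relu(self.layer1(x))))

    def forward(self, x):
        if self.use_checkpoint:
            # checkpoint() runs _body now but keeps only its input; the
            # activations inside _body are recomputed during backward.
            x = checkpoint(self._body, x, use_reentrant=False)
        else:
            x = self._body(x)
        return self.layer3(x)

def peak_extra_mb(model, x):
    # Extra memory allocated at peak during one forward + backward pass,
    # on top of what was already allocated (weights, input).
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    baseline = torch.cuda.memory_allocated()
    model(x).sum().backward()  # stored activations only matter if you backpropagate
    torch.cuda.synchronize()
    return (torch.cuda.max_memory_allocated() - baseline) / 1024**2

device = "cuda"  # the torch.cuda memory counters only track GPU tensors
input_tensor = torch.randn(32, 1024, device=device)  # Batch size 32, input dim 1024

# Without checkpointing
model_no_checkpoint = MyModel(use_checkpoint=False).to(device)
print(f"Peak extra memory without checkpointing: {peak_extra_mb(model_no_checkpoint, input_tensor):.2f} MB")

# With checkpointing
model_checkpoint = MyModel(use_checkpoint=True).to(device)
print(f"Peak extra memory with checkpointing: {peak_extra_mb(model_checkpoint, input_tensor):.2f} MB")
# At batch size 32 the activations are small next to the 1024x1024 weight
# gradients, so expect only a small gap; larger batches make it clearer.
```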

Activation checkpointing targets a specific problem: the intermediate activations that the forward pass stores so the backward pass can compute gradients consume more GPU memory than is available. With large models or large batch sizes, this quickly leads to an "out of memory" (OOM) error. Checkpointing addresses this by discarding selected intermediate activations during the forward pass and recomputing them during the backward pass, trading extra computation time for a substantial reduction in memory usage.
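
To make "stored activations" concrete, the sketch below counts the tensors autograd saves for backward during one forward pass, using the torch.autograd.graph.saved_tensors_hooks context manager. The model mirrors MyModel's layer sizes as an nn.Sequential (an assumption made so the block runs on its own); parameters are skipped, since they occupy memory regardless, and a tensor saved by two operations is counted once. Under torch.no_grad() nothing is saved, which is why inference needs far less memory than training.

```python
import torch
import torch.nn as nn

# Same layer sizes as MyModel, rebuilt as nn.Sequential so this block runs on its own.
model = nn.Sequential(
    nn.Linear(1024, 1024), nn.ReLU(),
    nn.Linear(1024, 1024), nn.ReLU(),
    nn.Linear(1024, 10),
)
x = torch.randn(32, 1024)

param_ptrs = {p.data_ptr() for p in model.parameters()}
saved = {}  # data_ptr -> bytes, so a tensor saved by two ops is counted once

def pack(t):
    # Called for every tensor autograd saves for the backward pass.
    if t.data_ptr() not in param_ptrs:  # weights are resident anyway; skip them
        saved[t.data_ptr()] = t.numel() * t.element_size()
    return t

def unpack(t):
    return t

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    out = model(x)
activation_bytes = sum(saved.values())

saved.clear()
with torch.no_grad(), torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    model(x)
no_grad_bytes = sum(saved.values())

print(f"Saved for backward: {activation_bytes / 1024**2:.3f} MB; under no_grad: {no_grad_bytes} bytes")
```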

The core idea is that, instead of storing every activation, you store only the inputs to certain "checkpointed" segments (modules or functions). During the backward pass, when the gradients for a checkpointed segment are needed, PyTorch re-runs that segment's forward pass from the stored inputs to regenerate its intermediate activations, then backpropagates through them. Because the activations inside each checkpointed segment are not held from the forward pass until the backward pass, peak memory drops.
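
Recomputation changes when activations exist, not what the gradients are. A quick check, assuming PyTorch 2.x (for the use_reentrant keyword): run the same small block once normally and once through checkpoint, and compare the gradients. The block and its sizes are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Sequential(nn.Linear(64, 256), nn.ReLU(), nn.Linear(256, 64))
x = torch.randn(8, 64, requires_grad=True)

# Reference: ordinary forward/backward; all of block's activations are kept.
block(x).sum().backward()
ref_x_grad = x.grad.clone()
ref_w_grad = block[0].weight.grad.clone()

# Checkpointed: only x is kept; block's internals are recomputed during backward.
x.grad = None
block.zero_grad()
checkpoint(block, x, use_reentrant=False).sum().backward()

print(torch.allclose(x.grad, ref_x_grad), torch.allclose(block[0].weight.grad, ref_w_grad))
```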

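You can apply torch.utils.checkpoint.checkpoint to any callable: an nn.Module, a bound method, or a plain function. checkpoint takes the callable as its first argument, followed by the inputs to pass to it; it runs the callable immediately and returns its output. Recent PyTorch releases also expect you to pass use_reentrant explicitly (and warn if you omit it), and the documentation recommends use_reentrant=False.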
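
A minimal sketch of checkpointing a plain function rather than a module. The function gelu_mlp and its weight tensors are hypothetical; the point is the calling convention: the callable first, then its positional inputs, with the result used exactly like a direct call.

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

def gelu_mlp(x, w1, w2):
    # A plain function (not an nn.Module): expand, apply GELU, project back.
    return F.gelu(x @ w1) @ w2

x = torch.randn(4, 16, requires_grad=True)
w1 = torch.randn(16, 64, requires_grad=True)
w2 = torch.randn(64, 16, requires_grad=True)

# Callable first, then its inputs; the call runs gelu_mlp right away
# and returns the same output as gelu_mlp(x, w1, w2).
y = checkpoint(gelu_mlp, x, w1, w2, use_reentrant=False)
y.sum().backward()  # the x @ w1 and GELU activations are recomputed here
print(y.shape, w1.grad.shape)
```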

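Because checkpoint executes the callable instead of returning a wrapped version of it, you apply it where the module is called, inside forward. For instance, if your model has model.block1, model.block2, and model.block3 and you want to checkpoint block2, you would write x = checkpoint(self.block2, x, use_reentrant=False) in forward instead of x = self.block2(x). Assigning model.block2 = checkpoint(model.block2) does not work: it tries to call block2 immediately, with no inputs. If you'd rather not edit forward, you can replace block2 with a small wrapper module whose own forward calls checkpoint; the rest of the architecture stays unchanged.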
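
The wrapper approach can be sketched as follows. CheckpointWrapper and Net are hypothetical names for illustration, not PyTorch APIs; the wrapper simply routes its call through checkpoint, so the parent model's forward does not need to change.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointWrapper(nn.Module):
    """Hypothetical helper (not a PyTorch API): checkpoints the wrapped module."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args):
        # Same result as self.module(*args), but activations are recomputed in backward.
        return checkpoint(self.module, *args, use_reentrant=False)

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
        self.block2 = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32), nn.ReLU())
        self.block3 = nn.Linear(32, 4)

    def forward(self, x):
        return self.block3(self.block2(self.block1(x)))

model = Net()
model.block2 = CheckpointWrapper(model.block2)  # swap in after construction
out = model(torch.randn(5, 32))
out.sum().backward()
print(out.shape)
```

One side effect of this design: parameter names gain a module segment (block2.0.weight becomes block2.module.0.weight), so load saved state dicts before wrapping, or remap their keys.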

The checkpoint function also accepts a preserve_rng_state argument, which defaults to True. With it enabled, checkpoint stashes the random number generator (RNG) state before running the segment and restores it for the recomputation, so stochastic layers such as dropout draw the same random masks in the recomputed forward pass as in the original one, and the gradients match a run without checkpointing. Setting preserve_rng_state=False skips this bookkeeping and saves a little overhead, but the recomputation may then draw different dropout masks than the forward pass used, so the gradients no longer correspond exactly to the outputs you computed. Only disable it if the checkpointed segment contains no randomness, or if you accept that mismatch and its possible effect on convergence.
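
The effect of preserve_rng_state can be observed directly with a dropout-only function. In this sketch, noisy and grad_of are hypothetical helpers; the RNG is re-seeded before every forward so all three runs start from the same state. With the default setting, the checkpointed gradient equals the reference exactly; with preserve_rng_state=False, the recomputation draws a fresh mask from the advanced RNG.

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

def noisy(x):
    # Dropout draws a new random mask from the global RNG on every call.
    return F.dropout(x * 2.0, p=0.5, training=True)

def grad_of(x, use_checkpoint, preserve_rng_state=True):
    x = x.detach().requires_grad_(True)
    torch.manual_seed(123)  # identical RNG state before every forward pass
    if use_checkpoint:
        y = checkpoint(noisy, x, use_reentrant=False, preserve_rng_state=preserve_rng_state)
    else:
        y = noisy(x)
    y.sum().backward()
    return x.grad

x = torch.randn(1000)
reference = grad_of(x, use_checkpoint=False)
same_rng = grad_of(x, use_checkpoint=True, preserve_rng_state=True)
no_rng = grad_of(x, use_checkpoint=True, preserve_rng_state=False)

print(torch.equal(reference, same_rng))  # True: recomputation replayed the forward's mask
print(torch.equal(reference, no_rng))    # almost surely False: a different mask was drawn
```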

Under the hood, checkpoint(fn, *inputs) runs fn without keeping its intermediate activations for the backward pass; it retains only the inputs (plus the RNG state, if preserve_rng_state=True). The original reentrant implementation (use_reentrant=True) does this with a custom autograd function: its forward runs fn under torch.no_grad(), and its backward re-runs fn on the saved inputs with gradients enabled, then backpropagates through that recomputed graph. The non-reentrant implementation (use_reentrant=False) instead uses saved-tensor hooks to drop the tensors autograd would normally save and to recompute them on demand during the backward pass.
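
A simplified sketch of the reentrant idea, written as a custom torch.autograd.Function. NaiveCheckpoint is a hypothetical name and this is not PyTorch's actual code: it handles a single tensor input and ignores RNG state. It also shows a known limitation of the reentrant approach: if x did not require grad, the output would not either, and the parameters inside run_fn would receive no gradients.

```python
import torch
import torch.nn as nn

class NaiveCheckpoint(torch.autograd.Function):
    """Toy reentrant-style checkpoint: one tensor input, no RNG handling."""

    @staticmethod
    def forward(ctx, run_fn, x):
        ctx.run_fn = run_fn
        ctx.save_for_backward(x)  # keep only the input
        with torch.no_grad():  # explicit: run_fn's intermediates are not recorded
            return run_fn(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        x = x.detach().requires_grad_(True)
        with torch.enable_grad():  # recompute, this time building a graph
            out = ctx.run_fn(x)
        # Backprop through the recomputed graph; parameter grads accumulate in .grad.
        torch.autograd.backward(out, grad_out)
        return None, x.grad  # no gradient for run_fn itself

torch.manual_seed(0)
layer = nn.Linear(8, 8)
run = lambda t: torch.relu(layer(t))
x = torch.randn(4, 8, requires_grad=True)

run(x).sum().backward()  # reference gradients
ref_x, ref_w = x.grad.clone(), layer.weight.grad.clone()

x.grad, layer.weight.grad, layer.bias.grad = None, None, None
NaiveCheckpoint.apply(run, x).sum().backward()
print(torch.allclose(x.grad, ref_x), torch.allclose(layer.weight.grad, ref_w))
```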

A common misconception is that checkpointing only works on entire layers. You can checkpoint arbitrary functions or sequences of operations: if a single nn.Module contains a complex block of operations, you can move just that block into a separate function or method and checkpoint it, rather than checkpointing the entire nn.Module. This gives you finer-grained control over which activations are discarded and recomputed.
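
A sketch of that finer-grained approach, assuming a hypothetical MLPBlock with a LayerNorm, a wide MLP, and a residual connection. Only the MLP method is checkpointed, because its hidden-sized activations are the largest in the block; the LayerNorm runs and stores its activations as usual.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

class MLPBlock(nn.Module):
    """Hypothetical block: LayerNorm, then a wide MLP, plus a residual connection."""

    def __init__(self, dim=64, hidden=256):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fc1 = nn.Linear(dim, hidden)
        self.fc2 = nn.Linear(hidden, dim)

    def _mlp(self, h):
        # The only checkpointed part: its two hidden-sized activations
        # (fc1 output and GELU output) are recomputed during backward.
        return self.fc2(F.gelu(self.fc1(h)))

    def forward(self, x):
        h = self.norm(x)  # not checkpointed; stored as usual
        return x + checkpoint(self._mlp, h, use_reentrant=False)

torch.manual_seed(0)
block = MLPBlock()
x = torch.randn(3, 10, 64, requires_grad=True)
y = block(x)
y.sum().backward()
print(y.shape)
```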

The most significant trade-off is increased training time. Because checkpointed segments run their forward pass a second time during the backward pass, each iteration takes longer. The overhead grows with how much of the model is checkpointed and how expensive those segments are, up to about one additional forward pass through the checkpointed segments. A good heuristic is to checkpoint larger, memory-intensive blocks of your model, where each unit of recomputation frees the most memory.
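
For models built as nn.Sequential, torch.utils.checkpoint.checkpoint_sequential splits the layers into a chosen number of segments and checkpoints them, which makes "how many blocks to checkpoint" a single parameter. A sketch assuming PyTorch 2.x; the layer sizes and segments=4 are arbitrary illustrative choices.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)
# 8 Linear+ReLU pairs -> a Sequential of 16 modules.
model = nn.Sequential(*[m for _ in range(8) for m in (nn.Linear(128, 128), nn.ReLU())])
x = torch.randn(16, 128, requires_grad=True)

# Split the 16 modules into 4 segments of 4. Each checkpointed segment keeps only
# its input; per the PyTorch docs, the last segment runs without checkpointing.
y = checkpoint_sequential(model, 4, x, use_reentrant=False)
y.sum().backward()

print(torch.allclose(y, model(x)))
```

Fewer segments store fewer boundary inputs, but each recomputed segment is larger, so its activations briefly reappear in bulk during backward; the best setting depends on the model.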

Once you’ve successfully reduced memory usage with activation checkpointing and avoided OOM errors, your next challenge will likely be optimizing the training speed, as checkpointing inherently slows down the training process.
