Gradient checkpointing lets you trade computation for memory when training large neural networks.

Let’s see it in action. Imagine a simple, deep network where each layer’s output is a tensor. The forward pass computes y = layer_n(layer_{n-1}(...(layer_1(x)))). To compute gradients during the backward pass, autograd needs the intermediate activations produced along the way. By default it keeps all of them in memory until backward() runs.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class DeepNet(nn.Module):
    def __init__(self, num_layers, hidden_dim):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers)
        ])
        self.relu = nn.ReLU()

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
            x = self.relu(x)
        return x

# Example usage
num_layers = 100
hidden_dim = 1024
model = DeepNet(num_layers, hidden_dim)
input_tensor = torch.randn(32, hidden_dim) # Batch size 32

# Without gradient checkpointing, autograd stores every layer's activations.
# This toy model fits comfortably, but the same pattern at realistic sizes
# (bigger batches, wider layers, longer sequences) is what runs out of memory.
# output = model(input_tensor)
# loss = output.mean()
# loss.backward()

# With gradient checkpointing
def run_layers(model, x, start, end):
    # Same Linear -> ReLU pattern as DeepNet.forward, for layers[start:end]
    for layer in model.layers[start:end]:
        x = model.relu(layer(x))
    return x

def segment_forward(model, x):
    # Checkpoint the first half of the network as one block
    mid_point = len(model.layers) // 2

    # Only the block's input is saved; the activations inside it are
    # recomputed during backward(). use_reentrant=False also gives the
    # block's parameters gradients even though input_tensor does not require grad.
    x = checkpoint(run_layers, model, x, 0, mid_point, use_reentrant=False)

    # The second half runs normally and stores its activations as usual
    return run_layers(model, x, mid_point, len(model.layers))

model.forward = lambda x: segment_forward(model, x) # Monkey patch for demonstration

output = model(input_tensor)
loss = output.mean()
loss.backward()

print("Gradient computation successful with checkpointing.")

The problem gradient checkpointing solves is simple: deep networks need a lot of memory to store the intermediate activations from the forward pass, because the backward pass uses them to compute gradients. When those activations, together with the weights, gradients, and optimizer state, no longer fit in GPU memory, you hit an Out-Of-Memory (OOM) error. Gradient checkpointing lets you train such models by discarding most intermediate activations during the forward pass and recomputing them when the backward pass needs them. You store far fewer activations, which can cut memory use substantially, at the cost of extra computation: the checkpointed parts of the forward pass run twice.
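To put rough numbers on this for the DeepNet example, here is a back-of-the-envelope estimate. It assumes float32 values and that each Linear + ReLU pair keeps about one batch × hidden tensor alive for backward (the ReLU output, which the next Linear also saves as its input). activation_bytes and parameter_bytes are hypothetical helpers for illustration, not part of PyTorch.

```python
# Rough memory estimate for the DeepNet example (float32 = 4 bytes per element).
# Assumption: each Linear + ReLU pair keeps about one (batch, hidden) tensor
# alive for backward, because the ReLU output is also what the next Linear saves.

def activation_bytes(num_layers, batch_size, hidden_dim, bytes_per_elem=4):
    # One saved (batch_size, hidden_dim) tensor per layer
    return num_layers * batch_size * hidden_dim * bytes_per_elem

def parameter_bytes(num_layers, hidden_dim, bytes_per_elem=4):
    # Each nn.Linear(hidden_dim, hidden_dim) has a weight matrix and a bias vector
    return num_layers * (hidden_dim * hidden_dim + hidden_dim) * bytes_per_elem

MiB = 1024 ** 2
print(f"activations, batch 32:   {activation_bytes(100, 32, 1024) / MiB:.1f} MiB")    # 12.5 MiB
print(f"activations, batch 4096: {activation_bytes(100, 4096, 1024) / MiB:.1f} MiB")  # 1600.0 MiB
print(f"parameters:              {parameter_bytes(100, 1024) / MiB:.1f} MiB")         # 400.4 MiB
```

At batch size 32 the weights dominate, but activation memory grows linearly with batch size (and, in other architectures, with sequence length or image resolution) while parameter memory stays fixed. That is why activations are often what pushes training over the limit.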

Internally, torch.utils.checkpoint.checkpoint runs the wrapped function or module without keeping the intermediate tensors autograd would normally save; it keeps only the inputs to that call. When backward() reaches the checkpointed region, PyTorch re-runs the function’s forward pass on those saved inputs to regenerate the intermediate values, then uses them to compute gradients. By default (preserve_rng_state=True) it also saves and restores the RNG state, so random operations such as dropout produce the same result during recomputation.
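The recomputation is easy to observe directly. The sketch below (assuming PyTorch 2.x; run_block and the calls counter are hypothetical names for this demo) counts how often a small block’s forward runs with and without checkpointing, and checks that the gradients agree.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64), nn.ReLU())
x = torch.randn(8, 64)

calls = {"n": 0}

def run_block(inp):
    calls["n"] += 1  # count how many times the block's forward runs
    return block(inp)

# Plain forward/backward: the block runs once and its activations are stored
run_block(x).sum().backward()
plain_grad = block[0].weight.grad.clone()
print("plain forward calls:", calls["n"])  # 1

# Checkpointed: only the input is kept; backward() re-runs run_block
block.zero_grad()
calls["n"] = 0
checkpoint(run_block, x, use_reentrant=False).sum().backward()
print("checkpointed forward calls:", calls["n"])  # 2

# Recomputation reproduces the same values, so the gradients match
print(torch.allclose(plain_grad, block[0].weight.grad))  # True
```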

The core levers you control are:

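  1. What to checkpoint: You can checkpoint individual layers, blocks of layers, or an entire model. The checkpoint function takes a callable followed by the arguments to pass to it. Wrapping a whole model in a single checkpoint call saves little, though: during backward its entire forward pass is rematerialized at once, so peak memory stays about the same.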
  2. Granularity: Deciding how much to checkpoint is crucial. Checkpointing too little won’t save enough memory; checkpointing everything maximizes recomputation, costing up to roughly one extra forward pass per training step. A common strategy is to checkpoint large, memory-intensive blocks of layers.
  3. The use_reentrant argument: checkpoint accepts a use_reentrant flag that selects between the original, reentrant implementation (the historical default) and a newer non-reentrant one. The PyTorch documentation recommends use_reentrant=False: the non-reentrant variant still produces gradients when none of the inputs require them, supports keyword arguments, and works with torch.autograd.grad. Recent releases warn if you don’t pass the flag explicitly, so pass use_reentrant=False unless you have a specific reason not to.
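The practical difference between the two modes shows up in the situation from the example above: the checkpointed block is the first part of the network, so its input does not require gradients. A minimal sketch, assuming PyTorch 2.x; first_layer_has_grad is a hypothetical helper written for this demo.

```python
import warnings
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

def first_layer_has_grad(use_reentrant):
    torch.manual_seed(0)
    first = nn.Linear(16, 16)   # checkpointed
    second = nn.Linear(16, 16)  # runs normally
    x = torch.randn(4, 16)      # plain input: requires_grad is False
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # the reentrant path warns about this case
        h = checkpoint(first, x, use_reentrant=use_reentrant)
    second(h).sum().backward()
    # Did the checkpointed layer receive a gradient?
    return first.weight.grad is not None

print("use_reentrant=True: ", first_layer_has_grad(True))   # False
print("use_reentrant=False:", first_layer_has_grad(False))  # True
```

With the reentrant variant, the checkpointed layer silently ends up with no gradient, so training would leave it frozen; the non-reentrant variant handles this case correctly.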

The trickiest part of using gradient checkpointing effectively is understanding where the memory bottleneck truly lies. Often, it’s not just the main nn.Linear or nn.Conv2d layers, but also the intermediate activations created by complex activation functions or custom operations within your forward pass. When you checkpoint a block, you are telling PyTorch, "Don’t store the outputs of these operations; just keep the inputs, and re-run the operations if their intermediate values are needed later for backprop." This recomputation happens during the backward pass, trading compute cycles for a smaller memory footprint.
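One way to see where activation memory goes is to count what autograd saves. The diagnostic sketch below uses torch.autograd.graph.saved_tensors_hooks to sum the bytes saved during one forward pass, counting each storage once and skipping the module’s own weights. It assumes PyTorch 2.x (for untyped_storage()); saved_activation_bytes and the hand-written Swish module are hypothetical, for illustration only, and the totals are approximate.

```python
import torch
import torch.nn as nn

def saved_activation_bytes(module, x):
    """Approximate bytes autograd saves for backward during one forward pass,
    counting each storage once and skipping the module's own parameters."""
    param_ptrs = {p.untyped_storage().data_ptr() for p in module.parameters()}
    seen = {}

    def pack(t):
        ptr = t.untyped_storage().data_ptr()
        if ptr not in param_ptrs:  # weights (and views of them) are not activations
            seen[ptr] = t.untyped_storage().nbytes()
        return t  # keep the tensor unchanged; we only observe it

    def unpack(t):
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        module(x)
    return sum(seen.values())

class Swish(nn.Module):
    # Hypothetical hand-written x * sigmoid(x): autograd keeps x and sigmoid(x)
    def forward(self, x):
        return x * torch.sigmoid(x)

x = torch.randn(32, 1024)
relu_block = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU())
swish_block = nn.Sequential(nn.Linear(1024, 1024), Swish())
print("Linear + ReLU: ", saved_activation_bytes(relu_block, x))
print("Linear + Swish:", saved_activation_bytes(swish_block, x))
```

The hand-written activation saves more than ReLU for the same layer size, which makes blocks like it good checkpointing candidates.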

If you checkpoint a very small segment, like a single nn.Linear layer in a deep network, you save little: checkpoint still stores that layer’s input, which is often the only activation the layer would have saved anyway, so you pay for recomputation without a meaningful memory benefit. Conversely, checkpointing too broadly increases recomputation and slows down every training step. The sweet spot is usually to checkpoint the largest, most memory-hungry blocks of your network.
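For a plain stack of layers, PyTorch’s torch.utils.checkpoint.checkpoint_sequential splits an nn.Sequential into a given number of segments and checkpoints them. The sketch below rebuilds a smaller Linear + ReLU stack as one nn.Sequential and uses about sqrt(num_layers) segments, a common heuristic rather than a rule; the sizes are arbitrary choices for the demo.

```python
import math
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)
num_layers, hidden_dim = 16, 256
# Same Linear -> ReLU stack as DeepNet, flattened into one nn.Sequential
modules = []
for _ in range(num_layers):
    modules += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]
net = nn.Sequential(*modules)

x = torch.randn(32, hidden_dim)
segments = int(math.sqrt(num_layers))  # heuristic: about sqrt(num_layers) segments -> 4

out = checkpoint_sequential(net, segments, x, use_reentrant=False)
out.mean().backward()
ckpt_grad = net[0].weight.grad.clone()

# Compare with an ordinary forward/backward on the same weights
net.zero_grad()
net(x).mean().backward()
print(torch.allclose(ckpt_grad, net[0].weight.grad))  # True
```

Per the PyTorch documentation, every segment except the last one discards its intermediate activations; the last segment runs normally, since its activations are needed right away when backward starts.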

Note that the memory saving comes from discarding the activations, not from the re-run itself; recomputation is the price you pay for it. During backward, checkpointed segments are rematerialized one at a time, so peak memory is roughly the saved segment inputs plus the activations of a single segment.
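A back-of-the-envelope estimate shows why segment size matters. Assume n layers whose activations all have the same size a, split into equal segments of s layers. The n/s segment inputs stay in memory for the whole step, and during backward one segment’s s activations are rematerialized at a time:

```latex
M(s) \approx a\left(\frac{n}{s} + s\right),
\qquad
\frac{dM}{ds} = a\left(1 - \frac{n}{s^{2}}\right) = 0
\;\Longrightarrow\;
s^{*} = \sqrt{n},
\qquad
M(s^{*}) \approx 2a\sqrt{n}
```

For the 100-layer DeepNet this suggests segments of about 10 layers, for a peak of roughly 20a instead of the roughly 100a needed without checkpointing, while each checkpointed layer’s forward pass still runs twice.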

The next hurdle you’ll likely encounter is optimizing the placement of your checkpoints to balance memory savings and training speed.
