Gradient checkpointing lets you trade compute for memory by recomputing activations during the backward pass instead of storing them all during the forward pass.
The example below runs a forward and backward pass through a BERT-style Transformer, first without and then with gradient checkpointing, and prints the allocated GPU memory in each case.
import torch
from torch import nn
from transformers import BertModel, BertConfig
# Assume a large batch size that would normally cause OOM
batch_size = 64
seq_length = 512
hidden_size = 1024
num_layers = 24
num_attention_heads = 16
# Configure a BERT-like model
config = BertConfig(
    vocab_size=30522,
    hidden_size=hidden_size,
    num_hidden_layers=num_layers,
    num_attention_heads=num_attention_heads,
    intermediate_size=hidden_size * 4,
    max_position_embeddings=seq_length,
    type_vocab_size=2,
)
# Instantiate the model
model = BertModel(config)
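# Dummy input data (token and segment IDs must be integer tensors for the embedding lookups)
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_length))
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)
token_type_ids = torch.zeros(batch_size, seq_length, dtype=torch.long)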
# --- Without Gradient Checkpointing ---
print("--- Without Gradient Checkpointing ---")
# Move model and data to GPU (the torch.cuda memory readings below require a CUDA device)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
token_type_ids = token_type_ids.to(device)
# Zero gradients
model.zero_grad()
# Forward and backward pass. Activations are stored during the forward pass,
# so an out-of-memory error can surface in either step.
try:
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
    )
    loss = outputs.last_hidden_state.mean()  # Dummy loss
    # Memory held at this point includes every activation saved for backward
    initial_memory_usage = torch.cuda.memory_allocated(device)
    print(f"Memory allocated before backward: {initial_memory_usage / 1024**2:.2f} MB")
    # Backward pass (consumes the stored activations)
    loss.backward()
    final_memory_usage_no_gc = torch.cuda.memory_allocated(device)
    print(f"Memory allocated after backward: {final_memory_usage_no_gc / 1024**2:.2f} MB")
    print("Backward pass successful without gradient checkpointing.")
except RuntimeError as e:
    print(f"Forward or backward pass failed: {e}")
    print("Likely ran out of GPU memory.")
# --- With Gradient Checkpointing ---
print("\n--- With Gradient Checkpointing ---")
# Release the first model and its outputs so both runs start from a comparable state
del model
outputs = loss = None
torch.cuda.empty_cache()
model = BertModel(config).to(device)
# Enable gradient checkpointing with a single call on the model. The method
# lives on PreTrainedModel; individual BertLayer modules do not provide it.
model.gradient_checkpointing_enable()
# Transformers only applies checkpointing while the model is in training mode
model.train()
# Reuse the dummy inputs from the first run (they are already on the device)
# Zero gradients
model.zero_grad()
# Forward pass
outputs_gc = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    token_type_ids=token_type_ids,
)
loss_gc = outputs_gc.last_hidden_state.mean()  # Dummy loss
# Check memory before backward pass
# With checkpointing, most intermediate activations are no longer held at this point
initial_memory_usage_gc = torch.cuda.memory_allocated(device)
print(f"Memory allocated before backward (GC enabled): {initial_memory_usage_gc / 1024**2:.2f} MB")
# Backward pass (will recompute activations as needed)
loss_gc.backward()
final_memory_usage_gc = torch.cuda.memory_allocated(device)
print(f"Memory allocated after backward (GC enabled): {final_memory_usage_gc / 1024**2:.2f} MB")
print("Backward pass successful with gradient checkpointing.")
# Clean up GPU memory
del model, input_ids, attention_mask, token_type_ids
if torch.cuda.is_available():
    torch.cuda.empty_cache()
Gradient checkpointing is a memory-saving technique where instead of storing all intermediate activations from the forward pass to use during the backward pass, you only store a subset. When the backward pass needs an activation that wasn’t stored, it recomputes it on the fly. This trades increased computation time for significantly reduced memory usage, making it possible to train larger models or use larger batch sizes.
The core idea is to divide the model into segments. During the forward pass, only the activations at the boundaries of these segments are saved. When the backward pass reaches a segment, it reruns that segment's forward computation from the saved boundary activation to rebuild the intermediate activations, uses them to compute gradients, and then discards them. The same steps repeat for each segment as the backward pass moves toward the input.
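The segment idea can be sketched in plain Python, without any framework. This is a toy illustration, not how PyTorch implements it: the scalar "layer" `f`, its derivative `df`, and both helper functions are hypothetical names invented for the example.

```python
import math

def f(x):
    # One toy "layer": a smooth nonlinearity whose gradient depends on its input.
    return math.tanh(x) + 0.5 * x

def df(x):
    # Derivative of f with respect to its input.
    return 1.0 - math.tanh(x) ** 2 + 0.5

def grad_full(x0, n_layers):
    # Baseline: store the input of every layer during the forward pass.
    inputs = []
    x = x0
    for _ in range(n_layers):
        inputs.append(x)
        x = f(x)
    grad = 1.0
    for x_in in reversed(inputs):
        grad *= df(x_in)
    return grad, len(inputs)

def grad_checkpointed(x0, n_layers, segment):
    # Forward: store only the input of each segment (the boundaries).
    boundaries = {}
    x = x0
    for i in range(n_layers):
        if i % segment == 0:
            boundaries[i] = x
        x = f(x)
    grad = 1.0
    peak_stored = 0
    # Backward: visit segments last to first, recomputing inside each one.
    for start in sorted(boundaries, reverse=True):
        end = min(start + segment, n_layers)
        local = [boundaries[start]]
        for _ in range(start, end - 1):
            local.append(f(local[-1]))
        # local[0] is itself a boundary, so do not count it twice.
        peak_stored = max(peak_stored, len(boundaries) + len(local) - 1)
        for x_in in reversed(local):
            grad *= df(x_in)
    return grad, peak_stored

full_grad, full_stored = grad_full(0.3, 24)
ckpt_grad, ckpt_stored = grad_checkpointed(0.3, 24, 4)
print(full_stored, ckpt_stored)
```

With 24 layers and segments of 4, the sketch holds 9 layer inputs at its peak instead of 24, and both functions return the same gradient. The price is one extra forward computation for most layers.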
The torch.utils.checkpoint.checkpoint function is the primary tool. You wrap specific parts of your model (often entire layers or blocks) with this function, and it handles the recomputation during the backward pass. With Hugging Face Transformers you rarely call it directly: a single call to gradient_checkpointing_enable() on the model wraps each block in encoder.layer (or the equivalent stack of blocks) for you.
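For a model you write yourself, wrapping a block with checkpoint might look like the sketch below. The `Block` and `Stack` classes and the `run` helper are hypothetical stand-ins for a real Transformer layer stack. The comparison at the end is only meant to show that checkpointing changes what is stored, not the gradients.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    # Hypothetical residual feed-forward block standing in for a Transformer layer.
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )

    def forward(self, x):
        return x + self.net(x)

class Stack(nn.Module):
    def __init__(self, dim, depth, use_checkpoint):
        super().__init__()
        self.blocks = nn.ModuleList([Block(dim) for _ in range(depth)])
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        for block in self.blocks:
            if self.use_checkpoint and self.training:
                # Keep only the block input; recompute the internals in backward.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x

def run(use_checkpoint):
    torch.manual_seed(0)  # identical weights in both runs
    model = Stack(dim=16, depth=3, use_checkpoint=use_checkpoint)
    x = torch.ones(2, 16, requires_grad=True)
    model(x).sum().backward()
    return x.grad.clone(), [p.grad.clone() for p in model.parameters()]

plain_x_grad, plain_param_grads = run(use_checkpoint=False)
ckpt_x_grad, ckpt_param_grads = run(use_checkpoint=True)
print(torch.allclose(plain_x_grad, ckpt_x_grad))
```

Checking `self.training` before checkpointing is a design choice in this sketch: during evaluation no backward pass follows, so there is nothing to gain from recomputation.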
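The transformers library makes this easier with model.gradient_checkpointing_enable(). When called on a PreTrainedModel that supports it, the method switches on checkpointing for the model's repeated blocks (such as BertLayer or GPT2Block). It is defined on the model, not on the individual layers in model.encoder.layer, so for finer control (for example, checkpointing only some layers) you wrap the chosen modules with torch.utils.checkpoint.checkpoint yourself. Checkpointing only takes effect in training mode, so make sure model.train() has been called before the forward pass.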
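A minimal sketch of the model-level switch, assuming a reasonably recent transformers release. The tiny configuration values are arbitrary and chosen only so the snippet runs quickly on CPU.

```python
import torch
from transformers import BertConfig, BertModel

# Deliberately tiny, hypothetical configuration so the sketch runs on CPU.
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    max_position_embeddings=16,
)
model = BertModel(config)

# One call on the model enables checkpointing for its encoder layers.
model.gradient_checkpointing_enable()
model.train()  # checkpointing is only applied in training mode
enabled = model.is_gradient_checkpointing

# A forward and backward pass works as usual; recomputation is transparent.
input_ids = torch.randint(0, config.vocab_size, (2, 16))
loss = model(input_ids=input_ids).last_hidden_state.mean()
loss.backward()
has_grad = model.embeddings.word_embeddings.weight.grad is not None

# The switch can be turned off again, for example before inference.
model.gradient_checkpointing_disable()
disabled = model.is_gradient_checkpointing
print(enabled, has_grad, disabled)
```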
The primary benefit is clear: fitting larger models or larger batch sizes onto your GPU. If you're hitting CUDA out of memory errors during training, especially late in the forward pass or early in the backward pass, when the stored activations are at their peak, gradient checkpointing is a good first thing to try.
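Beyond just enabling it, you can control how checkpointing works. The checkpoint function has arguments like preserve_rng_state (defaults to True so that random layers such as dropout behave identically during recomputation, at a small overhead) and use_reentrant (historically True; current PyTorch documentation recommends passing use_reentrant=False explicitly, because the non-reentrant implementation has fewer restrictions and supports more autograd features). Recent versions of transformers pass such options through the gradient_checkpointing_kwargs argument of gradient_checkpointing_enable(). For most Hugging Face models, the default settings are sufficient.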
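The effect of preserve_rng_state can be seen with a small dropout experiment. This is a sketch under simple assumptions (CPU, an all-ones input), chosen so that the dropout output and the gradient can be compared directly.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
dropout = nn.Dropout(p=0.5)  # modules start in training mode, so dropout is active
x = torch.ones(1000, requires_grad=True)

# Non-reentrant checkpointing with the RNG state saved and restored, so the
# recomputation in backward draws the same dropout mask as the forward pass.
out = checkpoint(dropout, x, use_reentrant=False, preserve_rng_state=True)
out.sum().backward()

# On an all-ones input the output is the scaled mask, and the gradient of the
# sum is the scaled mask from the recomputation. They agree only if both
# passes used the same mask.
masks_match = torch.equal(out.detach(), x.grad)
print(masks_match)
```

If a checkpointed block contains no random operations, passing preserve_rng_state=False skips the save and restore of the RNG state.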
One aspect that is often misunderstood is where the savings come from. Checkpointing reduces the activations that the forward pass keeps alive for the backward pass, so it lowers peak memory, which is typically reached at the end of the forward pass or early in the backward pass. It does not reduce the memory needed for the model parameters, their gradients, the optimizer state, or the input data. The savings come from avoiding the storage of intermediate feature maps across the entire network.
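A back-of-the-envelope estimate makes the scale of the savings concrete. The function below is a rough model, not a measurement: `per_layer_factor` is a hypothetical knob for how many boundary-sized tensors one layer saves for backward, and the value 16 used here is an arbitrary assumption. The remaining numbers come from the example configuration above.

```python
def activation_bytes(batch, seq, hidden, layers, per_layer_factor, bytes_per_value=4):
    # Size of one layer-boundary tensor with shape (batch, seq, hidden).
    boundary = batch * seq * hidden * bytes_per_value
    # ASSUMPTION: everything a layer saves for backward is modelled as
    # per_layer_factor boundary-sized tensors.
    inside_layer = per_layer_factor * boundary
    without_ckpt = layers * inside_layer
    # Per-layer checkpointing: one boundary per layer, plus the internals of
    # the single layer being recomputed at any moment during backward.
    with_ckpt = layers * boundary + inside_layer
    return without_ckpt, with_ckpt

without_ckpt, with_ckpt = activation_bytes(
    batch=64, seq=512, hidden=1024, layers=24, per_layer_factor=16
)
print(f"without: {without_ckpt / 1024**3:.1f} GiB, with: {with_ckpt / 1024**3:.1f} GiB")
```

Parameters, gradients, and optimizer state are deliberately left out of the estimate, since checkpointing does not change them.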
The next logical step after successfully implementing gradient checkpointing is to win back some of the training speed it costs, for example with mixed-precision training through torch.amp (which supersedes torch.cuda.amp in recent PyTorch releases) or with distributed training strategies.