Training massive LLMs is usually limited by GPU memory: even the largest single GPU cannot hold the weights, gradients, and optimizer state of a big enough model. Tensor parallelism breaks this bottleneck by splitting individual model layers across multiple GPUs.
Let’s see it in action. Imagine we have two GPUs and we want to train a small transformer layer.
import torch
import torch.nn as nn
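# NOTE: the import path depends on the Megatron-LM version. Megatron-Core exposes
# these classes under megatron.core.tensor_parallel; adjust this line to match your checkout.
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear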
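# Tensor parallelism degree: the number of GPUs each layer is split across.
tp_degree = 2
# Assumes one process per GPU (e.g. launched with torchrun) and that
# torch.distributed.init_process_group() has already been called.
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()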
# Model parameters
hidden_size = 1024
ffn_hidden_size = 4096
num_attention_heads = 16
seq_length = 512
# Split the linear layers for tensor parallelism
# For ColumnParallelLinear, the weight matrix is split along the output dimension.
# For RowParallelLinear, the weight matrix is split along the input dimension.
# Example: A typical transformer block might have:
# 1. Multi-Head Attention (self-attention)
# - Query, Key, Value projections (ColumnParallelLinear)
# - Output projection (RowParallelLinear)
# 2. Feed-Forward Network
# - First linear layer (ColumnParallelLinear)
# - Second linear layer (RowParallelLinear)
# Let's focus on a simplified feed-forward network for illustration:
# Input -> Linear1 -> Activation -> Linear2 -> Output
# Linear1: hidden_size -> ffn_hidden_size
# This layer will be split across GPUs.
# If tp_degree=2, each GPU will handle a portion of the ffn_hidden_size output.
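# NOTE: tp_degree and rank are passed explicitly here to keep the example readable.
# The real Megatron-LM constructors differ between versions (Megatron-Core, for
# example, takes a config object and an init_method), so check your version's signature.
linear1 = ColumnParallelLinear(
    hidden_size,
    ffn_hidden_size,
    bias=True,
    gather_output=False,  # Keep the output sharded; linear2 consumes it with input_is_parallel=True
    tp_degree=tp_degree,
    rank=rank,
)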
# Activation (e.g., GELU): element-wise, so it needs no change for TP and works on a sharded tensor as-is
activation = nn.GELU()
# Linear2: ffn_hidden_size -> hidden_size
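# This layer is also split, but along the input dimension instead of the output dimension.
# Each GPU receives the slice of the previous layer's output that it computed itself.
# Each GPU multiplies that slice by its local rows of the weight matrix, producing a
# partial sum of the full output; the partial sums are then added across GPUs.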
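linear2 = RowParallelLinear(
    ffn_hidden_size,
    hidden_size,
    bias=True,
    input_is_parallel=True,  # The input is already sharded across GPUs, so it is not scattered again
    tp_degree=tp_degree,  # Illustrative argument; see your Megatron-LM version for the real signature
    rank=rank,
)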
# Dummy input
batch_size = 4
input_tensor = torch.randn(seq_length, batch_size, hidden_size) # (seq, batch, hidden)
# Forward pass: how the work is distributed
# ColumnParallelLinear does NOT split the input. Every GPU multiplies the full input X
# by its *local* column block of the weight matrix.
# linear1's weight (hidden_size x ffn_hidden_size) is split along the second (output) dimension:
#   W_col = [W_0 | W_1 | ... | W_{tp_degree-1}]
#   Y_i   = X @ W_i        (GPU i produces ffn_hidden_size / tp_degree output features)
#   Y     = [Y_0 | Y_1 | ... | Y_{tp_degree-1}]
# With gather_output=True the Y_i are all-gathered so every GPU holds the full Y;
# with gather_output=False each GPU keeps only its own slice Y_i.
# RowParallelLinear splits the weight along the rows (input dimension):
#   W_row = [W_0; W_1; ...; W_{tp_degree-1}]
#   Y     = X @ W_row = X_0 @ W_0 + X_1 @ W_1 + ...
# where X_i is the slice of the input's last dimension that lines up with W_i.
# Each GPU computes one term, and an all-reduce sums the terms so every GPU gets the full Y.
# (With sequence parallelism enabled, Megatron-LM uses a reduce-scatter here instead.)
# With input_is_parallel=True the layer expects X_i directly, i.e. the ungathered
# output slice of a preceding ColumnParallelLinear (gather_output=False).
# Forward pass for linear1
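# ColumnParallelLinear handles any communication its output needs (an all-gather only if gather_output=True).
# Megatron layers return (output, bias). The second value is only populated when the bias
# add is deferred (skip_bias_add=True), so it is ignored here.
output_linear1, _ = linear1(input_tensor)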
# Activation
activated_output = activation(output_linear1)
# Forward pass for linear2
# RowParallelLinear takes the sharded activations, multiplies them by its local row block
# of the weight matrix, and sums the partial results across GPUs.
final_output, _ = linear2(activated_output)  # second return value (deferred bias) is unused here
print(f"Rank {rank}: Output shape {final_output.shape}")
This snippet shows how ColumnParallelLinear and RowParallelLinear from Megatron-LM take the place of standard PyTorch nn.Linear modules to implement tensor parallelism. ColumnParallelLinear splits the weight matrix A along the output dimension, distributing the computation of Y = XA across GPUs: each GPU computes Y_i = X A_i from the full input X and its own column slice A_i, and the slices Y_i are concatenated (all-gathered) only when the full output is needed. RowParallelLinear splits A along the input dimension: each GPU computes a partial product X_i A_i, where X_i is the slice of the input's last dimension that matches the row block A_i, and these partial products are summed across GPUs to give Y.
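The two identities above are easy to check without any GPUs. The following is a minimal single-process sketch, not Megatron-LM code: the "ranks" are just list entries, the helper names (matmul, split_columns, split_rows) are made up for this illustration, and plain Python lists with small integer matrices are used so the comparison is exact.

```python
# Single-process sketch: "ranks" are list entries, not real GPUs.

def matmul(x, w):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*w)] for row in x]

def split_columns(m, parts):
    """Split a matrix into `parts` equal column blocks."""
    width = len(m[0]) // parts
    return [[row[i * width:(i + 1) * width] for row in m] for i in range(parts)]

def split_rows(m, parts):
    """Split a matrix into `parts` equal row blocks."""
    height = len(m) // parts
    return [m[i * height:(i + 1) * height] for i in range(parts)]

tp_degree = 2
X = [[1, 2, 3, 4],
     [5, 6, 7, 8]]      # (batch=2, in=4)
A = [[1, 0, 2, 1],
     [0, 1, 1, 3],
     [2, 1, 0, 1],
     [1, 2, 1, 0]]      # (in=4, out=4)
Y_full = matmul(X, A)

# Column-parallel: every rank sees all of X but only a column block of A.
col_parts = [matmul(X, A_i) for A_i in split_columns(A, tp_degree)]
# Simulated all-gather: concatenate the output slices along the last dimension.
Y_col = [sum((part[r] for part in col_parts), []) for r in range(len(X))]

# Row-parallel: rank i sees the column block X_i and the matching row block A_i.
row_parts = [matmul(X_i, A_i)
             for X_i, A_i in zip(split_columns(X, tp_degree), split_rows(A, tp_degree))]
# Simulated all-reduce: element-wise sum of the partial outputs.
Y_row = [[sum(vals) for vals in zip(*rows)] for rows in zip(*row_parts)]

print(Y_col == Y_full, Y_row == Y_full)  # prints: True True
```

Both splits reproduce the single-device result. The difference is in what each rank holds before communication: an output slice in the column case, a full-size partial sum in the row case.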
The core problem tensor parallelism solves is fitting enormous models into GPU memory. Instead of one GPU holding the entire weight matrix of a large linear layer, each GPU holds only a slice of it. With hidden_size=1024 and ffn_hidden_size=4096, the first linear layer's full weight matrix is 1024x4096. With tp_degree=2 and a column split, each GPU holds a 1024x2048 slice, which halves the memory required for that layer’s weights.
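To put numbers on that, here is a small back-of-the-envelope sketch. The helper linear_weight_bytes is hypothetical, and the 2 bytes per parameter is an assumption (fp16 or bf16 weights); it counts weights only, not gradients, optimizer state, or activations.

```python
def linear_weight_bytes(in_features, out_features, tp_degree=1, bytes_per_param=2):
    """Weight bytes held per GPU when the output dimension is split tp_degree ways.

    bytes_per_param=2 assumes fp16/bf16 storage (an assumption, not from the article).
    """
    if out_features % tp_degree != 0:
        raise ValueError("out_features must be divisible by tp_degree")
    return in_features * (out_features // tp_degree) * bytes_per_param

full = linear_weight_bytes(1024, 4096)
per_gpu = linear_weight_bytes(1024, 4096, tp_degree=2)
print(f"full: {full / 2**20:.0f} MiB, per GPU: {per_gpu / 2**20:.0f} MiB")
# prints: full: 8 MiB, per GPU: 4 MiB
```

A single 1024x4096 layer is tiny, of course; the same halving applied to every large linear layer of a multi-billion-parameter model is what makes the difference.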
The mental model for tensor parallelism is that you’re partitioning the weight matrices of large linear layers (like those in MLPs and attention mechanisms) across multiple GPUs. During a forward pass, each GPU performs a partial matrix multiplication using its local weights, and communication is needed to combine the partial results. For ColumnParallelLinear, each GPU produces a slice of the output; when the full output is needed, an all-gather sends every GPU's slice to all the others. For RowParallelLinear, each GPU produces a full-size partial sum; an all-reduce adds these together so every GPU holds the final result (with sequence parallelism enabled, a reduce-scatter is used instead, leaving each GPU with one slice of the summed result). This decomposition lets you build models that are too large for a single GPU by pooling the memory of several GPUs for specific layers.
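As a rough illustration of where communication is and is not needed, the sketch below simulates a column-parallel layer, a GELU, and a row-parallel layer on two pretend ranks in one process. It is not Megatron-LM code: all_reduce_sum is a stand-in for the real collective, the matrices are arbitrary toy values, and the only claim is that the sharded computation matches the single-device one up to floating-point rounding.

```python
import math

def matmul(x, w):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*w)] for row in x]

def gelu(m):
    """Exact (erf-based) GELU applied element-wise."""
    return [[0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0))) for v in row] for row in m]

def all_reduce_sum(per_rank):
    """Simulated all-reduce: every rank ends up with the element-wise sum."""
    total = [[sum(vals) for vals in zip(*rows)] for rows in zip(*per_rank)]
    return [total for _ in per_rank]

tp_degree = 2
X = [[0.5, -1.0], [2.0, 0.25]]                              # (batch=2, hidden=2)
W1 = [[0.1, -0.2, 0.3, 0.4], [0.5, 0.6, -0.7, 0.8]]         # (hidden=2, ffn=4)
W2 = [[1.0, -1.0], [0.5, 0.5], [-0.25, 2.0], [0.0, 1.5]]    # (ffn=4, hidden=2)

# Single-device reference: Linear1 -> GELU -> Linear2.
reference = matmul(gelu(matmul(X, W1)), W2)

shard = len(W1[0]) // tp_degree
partials = []
for rank in range(tp_degree):
    lo, hi = rank * shard, (rank + 1) * shard
    W1_local = [row[lo:hi] for row in W1]   # column block of W1
    W2_local = W2[lo:hi]                    # matching row block of W2
    # No communication between the two layers: GELU is element-wise,
    # so it can be applied to the local output slice directly.
    h_local = gelu(matmul(X, W1_local))
    partials.append(matmul(h_local, W2_local))

# The only forward-pass communication: one all-reduce over the partial sums.
outputs = all_reduce_sum(partials)
max_err = max(abs(a - b) for out in outputs
              for ro, rr in zip(out, reference) for a, b in zip(ro, rr))
print(f"max abs error vs. single-device result: {max_err:.2e}")
```

The design point is the pairing: because the column-parallel output is left sharded and fed straight into the row-parallel layer, the intermediate all-gather disappears.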
The key levers you control are the tp_degree (how many GPUs to split across) and how you apply ColumnParallelLinear and RowParallelLinear to your model’s layers. Typically, the large linear layers within the transformer blocks are the candidates for tensor parallelism: the query, key, and value projections, the output projection of the multi-head attention, and the two linear layers in the feed-forward network. You’d use ColumnParallelLinear for layers where the output dimension is large and needs to be split (like the query, key, and value projections or the first MLP layer), and RowParallelLinear for layers where the input dimension is large and needs to be split (like the second MLP layer or the attention output projection).
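One way to sanity-check a choice of tp_degree is to tabulate the per-GPU weight shapes it implies. The sketch below is illustrative only: tp_layer_shapes and the layer names are hypothetical, it assumes the query, key, and value projections are fused into a single matrix of width 3 * hidden_size, and it assumes attention heads are divided evenly across GPUs.

```python
def tp_layer_shapes(hidden_size, ffn_hidden_size, num_attention_heads, tp_degree):
    """Per-GPU weight shapes (in_features, out_features) for one transformer block."""
    for name, value in [("num_attention_heads", num_attention_heads),
                        ("hidden_size", hidden_size),
                        ("ffn_hidden_size", ffn_hidden_size)]:
        if value % tp_degree != 0:
            raise ValueError(f"{name}={value} is not divisible by tp_degree={tp_degree}")
    return {
        # Column-parallel: the output dimension is divided.
        "attention.qkv (column)": (hidden_size, 3 * hidden_size // tp_degree),
        "mlp.fc1 (column)": (hidden_size, ffn_hidden_size // tp_degree),
        # Row-parallel: the input dimension is divided.
        "attention.out (row)": (hidden_size // tp_degree, hidden_size),
        "mlp.fc2 (row)": (ffn_hidden_size // tp_degree, hidden_size),
    }

for layer, shape in tp_layer_shapes(1024, 4096, 16, tp_degree=2).items():
    print(f"{layer:24s} {shape}")
```

With the article's sizes this gives 1024x1536 and 1024x2048 for the column-parallel layers and 512x1024 and 2048x1024 for the row-parallel ones. A tp_degree that does not divide the head count or the layer widths is rejected up front.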
Most people understand that tensor parallelism splits weights. What’s less obvious is how the communication patterns are dictated by the type of split and the order of operations. For a ColumnParallelLinear layer, the weight matrix W is split by columns: W = [W_0, W_1, ..., W_{p-1}]. The computation is Y = XW, and each GPU i computes Y_i = XW_i from the full input X. The full Y is the concatenation [Y_0, Y_1, ..., Y_{p-1}], so an all-gather is needed only if the next operation requires the whole of Y. For a RowParallelLinear layer, W is split by rows: W = [W_0^T, W_1^T, ..., W_{p-1}^T]^T. If the input X is already split along its last dimension into [X_0, X_1, ..., X_{p-1}] (which is typical after a ColumnParallelLinear layer), the computation becomes a sum of local products: Y = \sum_i X_i W_i. An all-reduce sums these partial results so that every GPU ends up with the full Y (Megatron-LM uses a reduce-scatter instead when sequence parallelism is enabled). The gather_output flag on ColumnParallelLinear and the input_is_parallel flag on RowParallelLinear tell the implementation which of these patterns to use: gather_output=False leaves the column-parallel output sharded, and input_is_parallel=True tells the row-parallel layer that its input is already sharded, so a column-parallel layer followed by a row-parallel layer needs only one all-reduce in the forward pass.
The next hurdle you’ll encounter is efficiently combining tensor parallelism with data parallelism to scale to thousands of GPUs.