Tensor parallelism allows you to split individual layers of a neural network across multiple GPUs.
Let’s see it in action. Imagine a simple matrix multiplication: Y = XW. If W is too large to fit on a single GPU, we can split W column-wise: W = [W_1 | W_2]. Then Y = X[W_1 | W_2] = [XW_1 | XW_2]. This means GPU 1 computes Y_1 = XW_1 and GPU 2 computes Y_2 = XW_2. Both GPUs need X, which is a broadcast operation. The results Y_1 and Y_2 are then concatenated to form Y.
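The column split can be checked numerically without any GPUs. Below is a minimal single-process simulation, assuming NumPy, in which two weight shards stand in for the two GPUs. It confirms that concatenating XW_1 and XW_2 reproduces XW.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 6))   # input, replicated on both "GPUs"
W = rng.standard_normal((6, 8))   # full weight matrix

# Column-wise split: each "GPU" keeps half of W's columns
W_1, W_2 = np.split(W, 2, axis=1)

Y_1 = X @ W_1   # computed on GPU 1
Y_2 = X @ W_2   # computed on GPU 2

# Concatenating along the column dimension (the all_gather) recovers Y = XW
Y = np.concatenate([Y_1, Y_2], axis=1)
```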
This is the core idea behind tensor parallelism. For a standard linear layer Y = XA^T + b (the PyTorch convention, where A has shape (out_features, in_features)), splitting the columns of Y means splitting A along its rows, i.e., splitting A^T column-wise. If we have N GPUs, A is split into N matrices A_1, A_2, ..., A_N, where each A_i holds 1/N of the rows of A (equivalently, 1/N of the columns of A^T), and the bias b is split the same way into b_1, ..., b_N. Each GPU i then computes Y_i = XA_i^T + b_i. Since all GPUs need the full input X, X is broadcast to all devices. After the computation, the partial results Y_1, Y_2, ..., Y_N are concatenated along the column dimension (an all_gather) to form the final output Y.
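The N-way split, including the bias, can be simulated the same way. This is a sketch assuming NumPy and a single process; `simulate_column_parallel` is a hypothetical helper name, and list index i stands for GPU i. A uses the PyTorch layout (out_features, in_features), so the split runs along A's rows.

```python
import numpy as np

def simulate_column_parallel(X, A, b, n_gpus):
    """Simulate Y = X A^T + b split across n_gpus along out_features."""
    A_shards = np.split(A, n_gpus, axis=0)   # each (out_features / N, in_features)
    b_shards = np.split(b, n_gpus)           # bias split the same way as A's rows
    # Every "GPU" i sees the full X (the broadcast) and computes its slice of Y
    Y_parts = [X @ A_i.T + b_i for A_i, b_i in zip(A_shards, b_shards)]
    # all_gather: concatenate the partial outputs along the column dimension
    return np.concatenate(Y_parts, axis=1)

rng = np.random.default_rng(1)
X = rng.standard_normal((3, 5))
A = rng.standard_normal((12, 5))
b = rng.standard_normal(12)
Y = simulate_column_parallel(X, A, b, n_gpus=4)
```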
This technique is particularly useful for transformer models where the feed-forward network (FFN) and attention layers contain large weight matrices. For example, in a transformer’s FFN, a typical layer is FFN(x) = max(0, xW1 + b1)W2 + b2. Here, both W1 and W2 can be split. If W1 has shape (d_model, d_ff) and W2 has shape (d_ff, d_model), we can split W1 column-wise and W2 row-wise.
Let W1 = [W1_1 | W1_2] and W2 = [[W2_1], [W2_2]], where W1_i holds 1/N of the columns of W1 and W2_i holds the matching 1/N of the rows of W2 (here N = 2). The bias b1 is split the same way as the columns of W1: b1 = [b1_1 | b1_2].
The input x is broadcast, and each GPU applies the first linear layer and the activation to its own column block:
GPU 1 computes H_1 = relu(xW1_1 + b1_1).
GPU 2 computes H_2 = relu(xW1_2 + b1_2).
The full intermediate activation would be H = relu(xW1 + b1) = [H_1 | H_2], the column-wise concatenation of the two blocks. It never has to be assembled: ReLU acts element-wise, so applying it to each block separately gives the same result, and no communication is needed between the two linear layers.
For the second linear layer, W2 is split row-wise, so each GPU multiplies its own block of H by the matching rows of W2:
GPU 1 computes Y_1 = H_1 W2_1.
GPU 2 computes Y_2 = H_2 W2_2.
Because HW2 = [H_1 | H_2][[W2_1], [W2_2]] = H_1 W2_1 + H_2 W2_2, the final output is Y = Y_1 + Y_2 + b2. The sum is an all_reduce operation, and b2 is added once after it rather than on every GPU (otherwise it would be counted N times).
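The column-then-row split of the FFN can be verified with a single-process NumPy simulation, sketched below. `tensor_parallel_ffn` is a hypothetical name and the sizes are arbitrary; the loop body is what each GPU would run on its own shards, and the final sum plays the role of the all_reduce.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def tensor_parallel_ffn(x, W1, b1, W2, b2, n_gpus):
    """Simulate FFN(x) = relu(x W1 + b1) W2 + b2 with the column/row split."""
    W1_shards = np.split(W1, n_gpus, axis=1)   # columns of W1: (d_model, d_ff / N)
    b1_shards = np.split(b1, n_gpus)           # matching slices of b1
    W2_shards = np.split(W2, n_gpus, axis=0)   # rows of W2: (d_ff / N, d_model)
    partial = []
    for W1_i, b1_i, W2_i in zip(W1_shards, b1_shards, W2_shards):
        H_i = relu(x @ W1_i + b1_i)   # local; ReLU is element-wise, so no gather
        partial.append(H_i @ W2_i)    # local partial result Y_i
    # all_reduce: sum the partial results, then add b2 exactly once
    return sum(partial) + b2

rng = np.random.default_rng(2)
d_model, d_ff = 8, 32
x = rng.standard_normal((4, d_model))
W1, b1 = rng.standard_normal((d_model, d_ff)), rng.standard_normal(d_ff)
W2, b2 = rng.standard_normal((d_ff, d_model)), rng.standard_normal(d_model)

y_parallel = tensor_parallel_ffn(x, W1, b1, W2, b2, n_gpus=4)
y_reference = relu(x @ W1 + b1) @ W2 + b2
```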
The key collectives involved are all_gather (to concatenate partial results along the split dimension) and all_reduce (to sum partial results or partial gradients); reduce_scatter appears when a summed result is kept sharded rather than replicated. For linear layers where the weight matrix W (of shape (in_features, out_features), so that Y = XW) is split column-wise, the input X is replicated on every GPU, each GPU computes Y_i = X W_i, and the full output Y, when it is needed, is an all_gather of the Y_i. For linear layers where W is split row-wise, the input arrives split along its feature dimension into X_i, each GPU computes Y_i = X_i W_i, and the output Y is an all_reduce (sum) of the Y_i.
In practice, libraries like Megatron-LM or DeepSpeed implement these tensor parallel operations efficiently. For example, a ColumnParallelLinear layer in Megatron-LM handles the column-wise split of the weight matrix: it takes the replicated input, performs a local matrix multiplication, and can optionally all_gather the output; in the backward pass, the gradient of its input is all-reduced. A RowParallelLinear layer handles the row-wise split: it performs a local matrix multiplication on an input that is already split across GPUs and then an all_reduce on the output.
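The choice of splitting strategy (column-wise vs. row-wise) depends on the specific layer and the desired communication pattern. For transformer blocks, it’s common to split the weight matrices of the FFN layers. The self-attention mechanism also has large weight matrices (the query, key, value, and output projections) that can be split; in Megatron-LM’s scheme, the query, key, and value projections are split column-wise so that each GPU owns a subset of the attention heads, and the output projection is split row-wise.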
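One standard way to split self-attention, used by Megatron-LM, is to split the query, key, and value projections column-wise so that each GPU owns whole attention heads, and to split the output projection row-wise. The NumPy sketch below simulates this in one process under simplifying assumptions: no masking, biases, or dropout, and a head count divisible by the number of GPUs. The function names are illustrative.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def attention(x, Wq, Wk, Wv, Wo, n_heads):
    """Unmasked multi-head self-attention without biases. x: (seq, d_model)."""
    seq, width = x.shape[0], Wq.shape[1]
    d_head = width // n_heads

    def heads(t):  # (seq, width) -> (n_heads, seq, d_head); columns are head-major
        return t.reshape(seq, n_heads, d_head).transpose(1, 0, 2)

    q, k, v = heads(x @ Wq), heads(x @ Wk), heads(x @ Wv)
    out = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(d_head)) @ v
    return out.transpose(1, 0, 2).reshape(seq, width) @ Wo

def tensor_parallel_attention(x, Wq, Wk, Wv, Wo, n_heads, n_gpus):
    # Q/K/V split column-wise: each "GPU" owns n_heads / n_gpus whole heads
    cols = [np.split(W, n_gpus, axis=1) for W in (Wq, Wk, Wv)]
    # Output projection split row-wise to match those heads
    Wo_shards = np.split(Wo, n_gpus, axis=0)
    partial = [attention(x, q_i, k_i, v_i, o_i, n_heads // n_gpus)
               for q_i, k_i, v_i, o_i in zip(*cols, Wo_shards)]
    return sum(partial)   # all_reduce

rng = np.random.default_rng(3)
d_model, n_heads = 16, 4
x = rng.standard_normal((5, d_model))
Wq, Wk, Wv, Wo = (rng.standard_normal((d_model, d_model)) for _ in range(4))
y_parallel = tensor_parallel_attention(x, Wq, Wk, Wv, Wo, n_heads, n_gpus=2)
y_reference = attention(x, Wq, Wk, Wv, Wo, n_heads)
```

Because softmax runs independently per head, each GPU's heads need nothing from the other GPUs until the row-split output projection, so the block again needs a single all_reduce in the forward pass.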
Consider a linear layer y = xW.
If W is split column-wise into W = [W_1 | W_2], then y = x[W_1 | W_2] = [xW_1 | xW_2].
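This requires the full x on every GPU (if x arrives sharded, for example along the sequence dimension, it must first be all_gathered), followed by the local computations y_1 = xW_1 and y_2 = xW_2. The final y is the concatenation of y_1 and y_2, which is an all_gather if the full output is needed on every GPU.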
For gradients, dL/dW = x^T dL/dy. If dL/dy is split into [dL/dy_1 | dL/dy_2], then dL/dW_1 = x^T dL/dy_1 and dL/dW_2 = x^T dL/dy_2. This is a simple local computation.
The input gradient is different. With W split column-wise, dL/dx = dL/dy W^T = [dL/dy_1 | dL/dy_2][W_1 | W_2]^T = dL/dy_1 W_1^T + dL/dy_2 W_2^T. Each GPU can compute only its own term, so the full dL/dx is obtained by summing the terms across GPUs, which is an all_reduce operation on dL/dx.
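Both gradient rules for the column split can be checked against the unsplit layer. This single-process NumPy sketch uses a random upstream gradient dL_dy as a stand-in for whatever the next layer sends back; the variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(4)
x = rng.standard_normal((4, 6))
W = rng.standard_normal((6, 8))
dL_dy = rng.standard_normal((4, 8))   # upstream gradient of y = xW

# Full (single-GPU) gradients for reference
dL_dW_full = x.T @ dL_dy
dL_dx_full = dL_dy @ W.T

# Column split across two "GPUs": W and dL/dy are split the same way
W_shards = np.split(W, 2, axis=1)
g_shards = np.split(dL_dy, 2, axis=1)

# Weight gradients are purely local
dL_dW_shards = [x.T @ g_i for g_i in g_shards]
# Each GPU holds one term of dL/dx; the all_reduce sums them
dL_dx = sum(g_i @ W_i.T for g_i, W_i in zip(g_shards, W_shards))
```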
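If W is split row-wise into W = [[W_1], [W_2]], then x must be split column-wise to match, x = [x_1 | x_2], and y = [x_1 | x_2][[W_1], [W_2]] = x_1W_1 + x_2W_2. Each GPU needs only its own slice x_i of the input’s feature dimension: no communication is required if x is already sharded this way (as it is when a column-split layer comes first), and if x is replicated, each GPU simply takes its slice. An all_reduce then sums the partial results y_1 = x_1W_1 and y_2 = x_2W_2: y = y_1 + y_2.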
For gradients, dL/dW_i = x_i^T dL/dy. After the forward all_reduce, every GPU holds the full dL/dy, so this is a local computation.
And dL/dx = dL/dy W^T. With W split row-wise, W^T = [W_1^T | W_2^T], so dL/dx = [dL/dy W_1^T | dL/dy W_2^T]: each GPU computes its own slice dL/dx_i = dL/dy W_i^T locally. If x was sharded, each GPU needs only its own slice and no communication is required; the full dL/dx is the concatenation of the slices (an all_gather) only when the complete gradient is needed on every GPU.
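The row-split forward and backward rules can be checked the same way. This is a single-process NumPy sketch; list index i stands for GPU i, and dL_dy is a random stand-in for the upstream gradient that every GPU holds after the all_reduce.

```python
import numpy as np

rng = np.random.default_rng(5)
x = rng.standard_normal((4, 6))
W = rng.standard_normal((6, 8))
dL_dy = rng.standard_normal((4, 8))

# Row split of W with the matching column split of x
W_shards = np.split(W, 2, axis=0)   # each (3, 8)
x_shards = np.split(x, 2, axis=1)   # each (4, 3)

# Forward: local partial products, then all_reduce (sum)
y = sum(x_i @ W_i for x_i, W_i in zip(x_shards, W_shards))

# Backward: every GPU holds the full dL/dy after the all_reduce
dL_dW_shards = [x_i.T @ dL_dy for x_i in x_shards]   # local
dL_dx_shards = [dL_dy @ W_i.T for W_i in W_shards]   # local slices of dL/dx
```

Concatenating `dL_dx_shards` along the columns (the optional all_gather) gives the full dL/dx = dL/dy W^T.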
The communication overhead is a critical factor. Column-wise splitting needs no forward communication when the input is already replicated (plus an all_gather if the full output is required), and an all_reduce on the gradient of the input in the backward pass. Row-wise splitting requires an all_reduce on the output in the forward pass, but no communication for the gradient of the input when the input is sharded. The choice depends on the relative sizes of tensors and the network topology.
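To get a feel for the sizes involved, the arithmetic below estimates how much data each GPU sends per collective. It assumes ring-based collectives, for which an all_gather sends (N-1)/N of the full tensor per GPU and an all_reduce sends twice that; the activation sizes are hypothetical, and real libraries may choose other algorithms.

```python
def ring_all_gather_bytes(full_bytes, n):
    # Each rank sends (n - 1) / n of the gathered tensor
    return (n - 1) / n * full_bytes

def ring_all_reduce_bytes(full_bytes, n):
    # reduce_scatter followed by all_gather: twice the all_gather traffic
    return 2 * (n - 1) / n * full_bytes

# Hypothetical activation: batch 8, sequence 2048, hidden 4096, fp16 (2 bytes)
activation_bytes = 8 * 2048 * 4096 * 2
n_gpus = 8
mib = 2**20
gather_mib = ring_all_gather_bytes(activation_bytes, n_gpus) / mib   # 112.0
reduce_mib = ring_all_reduce_bytes(activation_bytes, n_gpus) / mib   # 224.0
```

The volume scales with batch size, sequence length, and hidden size (the activation tensor), not with the number of parameters in the split weight.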
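A common strategy for transformers is to split the FFN layers. The FFN (often called the MLP) consists of two linear layers. The first linear layer (W1) is split column-wise (ColumnParallelLinear), and the second linear layer (W2) is split row-wise (RowParallelLinear). The column-split output is exactly the sharded input that the row-split layer expects, so the activation in between never has to be gathered: the whole block needs one all_reduce in the forward pass (after W2) and one in the backward pass (on the gradient of the block’s input).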
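The claim that the paired layers need only one all_reduce in the backward pass can be checked with a single-process NumPy simulation. Each simulated GPU runs the ordinary MLP backward pass on its own W1 column block and W2 row block, and the only cross-GPU step is summing the input gradients. Biases are omitted for brevity, and `mlp_grads` is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(6)
n_gpus, d_model, d_ff = 2, 4, 8
x = rng.standard_normal((3, d_model))
W1 = rng.standard_normal((d_model, d_ff))
W2 = rng.standard_normal((d_ff, d_model))
dL_dy = rng.standard_normal((3, d_model))   # upstream gradient, replicated

def mlp_grads(x, W1, W2, dL_dy):
    """Gradients of y = relu(x W1) W2 with respect to W1, W2, and x."""
    z = x @ W1
    h = np.maximum(z, 0.0)
    dL_dh = dL_dy @ W2.T
    dL_dz = dL_dh * (z > 0)   # ReLU backward is element-wise, so it stays local
    return x.T @ dL_dz, h.T @ dL_dy, dL_dz @ W1.T

# Reference: the unsplit MLP
dW1_ref, dW2_ref, dx_ref = mlp_grads(x, W1, W2, dL_dy)

# Tensor parallel: each "GPU" runs the same code on its shards only
results = [mlp_grads(x, W1_i, W2_i, dL_dy)
           for W1_i, W2_i in zip(np.split(W1, n_gpus, axis=1),
                                 np.split(W2, n_gpus, axis=0))]
# Weight gradients stay sharded; they are concatenated here only for checking
dW1 = np.concatenate([r[0] for r in results], axis=1)
dW2 = np.concatenate([r[1] for r in results], axis=0)
dx = sum(r[2] for r in results)   # the single backward all_reduce
```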
The exact implementation details involve careful management of tensor shapes and communication primitives. PyTorch stores a linear layer’s weight W with shape (out_features, in_features) and computes Y = X W^T, so a column-wise split (in the y = xW sense used above) splits W along out_features. Each of the N GPUs then holds a weight matrix of shape (out_features / N, in_features). The input X of shape (batch_size, in_features) is broadcast, and each GPU computes Y_i = X W_i^T, resulting in Y_i of shape (batch_size, out_features / N). The final output Y is an all_gather of the Y_i along the last dimension, forming a (batch_size, out_features) tensor.
When splitting W row-wise, the split runs along in_features, so each GPU has a weight matrix of shape (out_features, in_features / N). The input X of shape (batch_size, in_features) is split along its last dimension, so GPU i holds X_i of shape (batch_size, in_features / N), typically because the preceding column-parallel layer already produced it in that form. Each GPU i computes Y_i = X_i W_i^T, of shape (batch_size, out_features), and the final output Y is an all_reduce sum of the Y_i.
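The shapes in PyTorch’s (out_features, in_features) layout can be traced with a NumPy sketch that chains a column-split layer into a row-split layer, with no activation in between to keep the check simple. Everything runs in one process; list index i stands for GPU i, and the sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(7)
N, batch, d_in, d_hidden, d_out = 4, 2, 8, 16, 8
X = rng.standard_normal((batch, d_in))
W_a = rng.standard_normal((d_hidden, d_in))    # PyTorch layout: (out_features, in_features)
W_b = rng.standard_normal((d_out, d_hidden))

# Column parallel: split along out_features, each shard (d_hidden / N, d_in)
W_a_shards = np.split(W_a, N, axis=0)
Y_a = [X @ W_i.T for W_i in W_a_shards]        # each (batch, d_hidden / N)

# Row parallel: split along in_features, each shard (d_out, d_hidden / N).
# GPU i feeds its own Y_a[i] straight in, so nothing is gathered in between.
W_b_shards = np.split(W_b, N, axis=1)
Y_b = [Y_i @ W_i.T for Y_i, W_i in zip(Y_a, W_b_shards)]   # each (batch, d_out)
Y = sum(Y_b)                                              # all_reduce
```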
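The most surprising thing is that tensor parallelism often requires more communication than pipeline parallelism for the same number of GPUs: it runs collectives inside every layer, in both the forward and the backward pass, whereas pipeline parallelism only sends activations and their gradients between stages. For that reason it is usually confined to GPUs within a single node connected by fast links such as NVLink. It can still achieve higher hardware utilization, because it splits every layer’s weights and compute evenly across GPUs and avoids the idle “bubble” that pipeline stages incur while waiting for one another.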
Here’s a conceptual snippet of how ColumnParallelLinear might work with PyTorch and distributed primitives:
```python
import torch
import torch.distributed as dist
from torch import nn


class ColumnParallelLinear(nn.Module):
    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.device = device
        self.dtype = dtype
        world_size = dist.get_world_size()
        self.num_gpus = world_size
        # Split out_features across GPUs (check divisibility before dividing)
        if out_features % self.num_gpus != 0:
            raise ValueError("out_features must be divisible by world_size for column parallelism")
        self.out_features_per_gpu = out_features // self.num_gpus
        # Each rank stores only its own shard of the weight and bias
        self.weight = nn.Parameter(torch.empty(self.out_features_per_gpu, in_features, device=device, dtype=dtype))
        if bias:
            self.bias = nn.Parameter(torch.empty(self.out_features_per_gpu, device=device, dtype=dtype))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        # Same scheme as nn.Linear. fan_in is in_features for every shard, so each
        # shard follows the full layer's distribution. Ranks need different RNG
        # seeds here, or every shard is initialized identically.
        nn.init.kaiming_uniform_(self.weight, a=5**0.5)
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / fan_in**0.5
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        # x has shape (batch_size, in_features) and is replicated on every rank
        batch_size = x.shape[0]
        # Local matrix multiplication: y_local has shape (batch_size, out_features_per_gpu)
        y_local = nn.functional.linear(x, self.weight, self.bias)
        # all_gather_into_tensor concatenates along dim 0, so gather into a
        # (world_size * batch_size, out_features_per_gpu) buffer first...
        y_gathered = torch.empty(self.num_gpus * batch_size, self.out_features_per_gpu,
                                 device=x.device, dtype=x.dtype)
        dist.all_gather_into_tensor(y_gathered, y_local.contiguous())
        # ...then move each rank's block into its column slot: (batch_size, out_features).
        # torch.distributed collectives are not tracked by autograd, so this sketch
        # is forward-only; training needs a custom autograd.Function.
        return torch.cat(y_gathered.chunk(self.num_gpus, dim=0), dim=-1)


# Example usage:
# Assume 2 GPUs are initialized (e.g. dist.init_process_group("nccl")) and this code runs on both
# world_size = 2, rank = 0 or 1; call torch.cuda.set_device(rank) so 'cuda' is this rank's GPU
# in_features = 1024, out_features = 4096 (so each GPU gets 2048 out_features)
# model = ColumnParallelLinear(1024, 4096, device='cuda', dtype=torch.float16)
# input_tensor = torch.randn(32, 1024, device='cuda', dtype=torch.float16)  # same values on every rank
# output_tensor = model(input_tensor)
# print(output_tensor.shape)  # Should be torch.Size([32, 4096])
```
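A subtle point in the forward pass is the layout produced by `dist.all_gather_into_tensor`: it places each rank’s tensor one after another along dimension 0. For per-rank outputs of shape (batch_size, out_features / N), the gathered buffer therefore holds row blocks, not column blocks, and reinterpreting it directly as (batch_size, out_features) scrambles the values. The NumPy simulation below only mimics that dim-0 concatenation (it does not call `torch.distributed`) and shows the split-and-concatenate step that restores the correct layout.

```python
import numpy as np

rng = np.random.default_rng(8)
world_size, batch, out_per_rank = 2, 3, 4
X = rng.standard_normal((batch, 5))
W = rng.standard_normal((world_size * out_per_rank, 5))   # (out_features, in_features)
y_locals = [X @ W_i.T for W_i in np.split(W, world_size, axis=0)]   # one per rank

# Mimic all_gather_into_tensor: rank blocks stacked along dim 0
gathered = np.concatenate(y_locals, axis=0)   # (world_size * batch, out_per_rank)

wrong = gathered.reshape(batch, world_size * out_per_rank)             # naive view
right = np.concatenate(np.split(gathered, world_size, axis=0), axis=1)  # blocks to columns
```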
The actual implementation in libraries like Megatron-LM is more sophisticated, handling gradient computations and integration with other parallelism strategies. The key takeaway is the decomposition of a single layer’s computation and communication across multiple devices.
The next challenge is often optimizing the communication patterns, especially when combining tensor parallelism with data or pipeline parallelism, leading to complex interdependencies between all_gather, reduce_scatter, and all_reduce operations across different ranks and stages.