FlashAttention 2 doesn’t just make attention faster; it restructures how attention is computed by fusing operations and optimizing memory access, turning a memory-bound bottleneck into a largely compute-bound one. The result is still exact attention, not an approximation.

Let’s see it in action. Imagine training a moderately sized transformer model, say with 1 billion parameters, on a dataset like C4. Without FlashAttention 2, a single training step might take 5 seconds. With FlashAttention 2, that same step can drop to under 2 seconds, a more than 2.5x speedup. This isn’t just about shaving off milliseconds; it means finishing a full training run in days instead of weeks, or iterating on model architectures much more rapidly.

Here’s the core idea: standard attention, especially for long sequences, is limited by memory bandwidth rather than by arithmetic. Every token attends to every other token, so a sequence of length N produces an N x N score matrix. A standard implementation computes Q @ K.T, writes that matrix to HBM (High Bandwidth Memory, the GPU’s main memory), reads it back to apply softmax and dropout, writes the result again, and reads it once more for the second multiplication (Scores @ V). FlashAttention 2 tackles this by keeping the intermediate work on-chip in SRAM (Static Random-Access Memory), which is far smaller than HBM but much faster. It does this through a technique called "tiling": it breaks the large matrices into blocks small enough to fit in SRAM, computes attention block by block, fuses the softmax and dropout operations into the same kernel, and writes only the final output (plus a small per-row softmax statistic kept for the backward pass) back to HBM. This drastically reduces the number of slow HBM reads and writes.

The primary lever you control is simply enabling FlashAttention 2 in your deep learning framework. In PyTorch, that typically looks like this:

import torch
from flash_attn import flash_attn_func

# Assuming q, k, v are your batched query, key, and value tensors
# q.shape: (batch_size, seq_len, num_heads, head_dim)
# k.shape: (batch_size, seq_len, num_heads, head_dim)
# v.shape: (batch_size, seq_len, num_heads, head_dim)

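# Standard attention (for comparison)
# Note: PyTorch's SDPA expects (batch_size, num_heads, seq_len, head_dim), so swap dims 1 and 2 going in and coming out
# attn_output = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=None, dropout_p=0.0, is_causal=False).transpose(1, 2)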

# Using FlashAttention 2 (q, k, v must be fp16 or bf16 tensors on a CUDA device)
# Note: For causal attention, use causal=True
# For multi-query/grouped-query attention, k and v can have fewer heads than q, as long as q's head count is divisible by theirs
attn_output = flash_attn_func(q, k, v, causal=False, dropout_p=0.0)

You’ll need to install the flash-attn library. Installation requires a CUDA-enabled GPU environment with PyTorch already present, and can be done via pip: pip install flash-attn --no-build-isolation (the form the project’s README recommends). The kernels are optimized for specific GPU architectures (like Ampere, Ada, and Hopper) and expect fp16 or bf16 inputs, so performance gains are most pronounced on newer hardware. flash_attn_func takes your query, key, and value tensors as input, along with optional arguments like causal (for decoder-style attention) and dropout_p.
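Because the library only installs on some setups, it can help to guard the call. The following is a minimal sketch, not part of flash-attn itself: the helper names flash_attn_available and attention are hypothetical, and the fallback assumes PyTorch 2.0 or later and equal query and key sequence lengths (the two backends align the causal mask differently when those lengths differ).

```python
import importlib.util


def flash_attn_available() -> bool:
    """Return True if the flash_attn package can be found in this environment."""
    return importlib.util.find_spec("flash_attn") is not None


def attention(q, k, v, causal=False):
    """Hypothetical wrapper. q, k, v: (batch_size, seq_len, num_heads, head_dim)."""
    import torch  # imported lazily so the availability check works without torch

    half_precision = q.dtype in (torch.float16, torch.bfloat16)
    if flash_attn_available() and q.is_cuda and half_precision:
        from flash_attn import flash_attn_func
        return flash_attn_func(q, k, v, dropout_p=0.0, causal=causal)

    # Fallback: PyTorch's built-in attention uses (batch, heads, seq, dim),
    # so swap the seq_len and num_heads axes going in and coming out.
    out = torch.nn.functional.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=causal
    )
    return out.transpose(1, 2)
```

With a wrapper like this, the same model code can run on a machine without a supported GPU, just without the FlashAttention 2 speedup.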

The key internal mechanism that makes this possible is the optimized kernel implementation. Instead of relying on generic matrix multiplication kernels that make many HBM round trips, FlashAttention 2 uses custom CUDA kernels built around four ideas:

  1. Tiled Computation: Load blocks of Q, K, and V into SRAM.
  2. Fused Operations: Compute Q @ K.T, apply softmax, and multiply by V within SRAM for each tile. This avoids writing intermediate QK^T matrices to HBM.
  3. Online Softmax: The softmax is computed incrementally as tiles are processed, avoiding the need to materialize the entire attention score matrix.
  4. Efficient Reduction: The results from different tiles are carefully reduced in SRAM to produce the final output.
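The four steps above can be sketched in plain Python for a single attention head. This is an illustrative CPU model under simplifying assumptions (lists instead of GPU tensors, one query row at a time, no dropout or masking), not the actual CUDA kernel; the function names and the block parameter are made up for the example.

```python
import math


def naive_attention(Q, K, V):
    """Reference version: builds the full row of scores for every query."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        row_max = max(scores)
        weights = [math.exp(s - row_max) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[j] for w, v in zip(weights, V)) / total
                    for j in range(len(V[0]))])
    return out


def tiled_attention(Q, K, V, block=2):
    """Visits K and V one block at a time, keeping only running statistics."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        running_max = float("-inf")   # largest score seen so far
        running_sum = 0.0             # sum of exp(score - running_max)
        acc = [0.0] * len(V[0])       # unnormalized weighted sum of V rows
        for start in range(0, len(K), block):
            k_blk = K[start:start + block]            # step 1: load a tile
            v_blk = V[start:start + block]
            s = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
            new_max = max(running_max, max(s))
            fix = math.exp(running_max - new_max)     # 0.0 on the first tile
            p = [math.exp(x - new_max) for x in s]    # step 3: online softmax
            running_sum = running_sum * fix + sum(p)
            acc = [a * fix + sum(pi * v[j] for pi, v in zip(p, v_blk))
                   for j, a in enumerate(acc)]        # step 2: fused P @ V
            running_max = new_max
        out.append([a / running_sum for a in acc])    # step 4: final reduction
    return out
```

Both functions return the same values up to floating-point rounding, but tiled_attention never holds more than one block of scores at a time, which is the property that lets the real kernel keep its working set in SRAM.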

This fusion and on-chip processing are what create the dramatic speedup and memory savings. For example, with a sequence length of 4096, a batch size of 32, 12 heads, and a head dimension of 64, a standard attention mechanism might materialize an intermediate attention score matrix of 32 * 12 * 4096 * 4096 elements (about 6.4 billion). That is far too large for fast on-chip memory, so it has to be written to and read back from HBM. FlashAttention 2 avoids this entirely by processing smaller chunks.
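To put that example in bytes, here is a quick back-of-the-envelope calculation. It assumes 2 bytes per element (fp16 or bf16); the helper names are invented for the illustration.

```python
def score_matrix_bytes(batch_size, num_heads, seq_len, bytes_per_element=2):
    """Bytes needed to store a full seq_len x seq_len score matrix per head."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_element


def qkv_tensor_bytes(batch_size, num_heads, seq_len, head_dim, bytes_per_element=2):
    """Bytes needed for one of the Q, K, or V tensors."""
    return batch_size * seq_len * num_heads * head_dim * bytes_per_element


GIB = 1024 ** 3
MIB = 1024 ** 2

scores = score_matrix_bytes(32, 12, 4096)
one_input = qkv_tensor_bytes(32, 12, 4096, 64)

print(f"score matrix:    {scores / GIB:.0f} GiB")     # 12 GiB
print(f"each of Q, K, V: {one_input / MIB:.0f} MiB")  # 192 MiB
print(f"ratio:           {scores // one_input}x")     # 64x, i.e. seq_len / head_dim
```

The intermediate matrix is 64 times larger than any of the inputs, and that ratio grows linearly with sequence length.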

The aspect most people miss is how FlashAttention 2 handles the softmax normalization across tiles. A numerically stable softmax needs two statistics over the entire row: the maximum value and the sum of exponentials. When you’re tiling, you don’t have the whole row available at once. FlashAttention 2 employs an "online softmax" algorithm. It keeps track of the current maximum and sum of exponentials seen so far for a row. As each new tile arrives, it updates these statistics and rescales the partial results accumulated from previous tiles so that everything is expressed relative to the new maximum. This ensures the final output is correctly normalized without ever needing to load the full attention score matrix.
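The statistics update on its own is small enough to show for a single row of scores. This is a sketch of the general online softmax idea in plain Python, with hypothetical function names, rather than code taken from the library.

```python
import math


def online_softmax_stats(row, tile_size):
    """Scan a row tile by tile, returning its max and sum of exponentials."""
    running_max = float("-inf")
    running_sum = 0.0
    for start in range(0, len(row), tile_size):
        tile = row[start:start + tile_size]
        new_max = max(running_max, max(tile))
        # Rescale the old sum so it is expressed relative to the new maximum,
        # then add this tile's contribution.
        running_sum = (running_sum * math.exp(running_max - new_max)
                       + sum(math.exp(x - new_max) for x in tile))
        running_max = new_max
    return running_max, running_sum


def softmax_from_stats(row, row_max, row_sum):
    """Normalize a row once its max and sum of exponentials are known."""
    return [math.exp(x - row_max) / row_sum for x in row]


row = [2.0, -1.0, 0.5, 7.0, 3.0, 3.0, -4.0]
row_max, row_sum = online_softmax_stats(row, tile_size=3)
probs = softmax_from_stats(row, row_max, row_sum)
```

The rescaling factor exp(running_max - new_max) is what the text above describes: whenever a later tile raises the maximum, everything accumulated so far is shrunk to match.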

The next frontier is understanding how to best integrate this into mixed-precision training and exploring its implications for even longer sequence lengths.
