The LLM KV cache is more than an implementation detail. By spending memory to avoid recomputation, it’s the difference between a sluggish, character-by-character chatbot and something that feels almost instantaneous.

Let’s watch it in action. Imagine we’re generating text with a hypothetical LLM.

import torch

# Assume 'model' is a pre-trained causal LLM and 'tokenizer' is its associated tokenizer.
# The call below follows the Hugging Face transformers convention
# (past_key_values / use_cache); other frameworks name these arguments differently.

@torch.no_grad()
def generate_text_with_kv_cache(prompt, model, tokenizer, max_new_tokens=50):
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    generated_ids = input_ids

    # No cache yet: the first forward pass processes the whole prompt
    kv_cache = None
    model_input = input_ids

    for _ in range(max_new_tokens):
        # Run the model on the new token(s) only and get the updated KV cache
        outputs = model(input_ids=model_input, past_key_values=kv_cache, use_cache=True)
        logits = outputs.logits
        kv_cache = outputs.past_key_values  # This is the magic!

        # Logits for the position after the last token
        next_token_logits = logits[:, -1, :]

        # Greedy decoding (simplified for demonstration); shape (1, 1)
        next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)

        # Append the new token to the sequence
        generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)

        # Stop if end-of-sequence token is generated
        if next_token_id.item() == tokenizer.eos_token_id:
            break

        # Next step: feed only the new token; the cache already covers everything before it
        model_input = next_token_id

    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)

# Example usage:
# prompt = "The quick brown fox"
# generated_text = generate_text_with_kv_cache(prompt, model=my_llm_model, tokenizer=my_llm_tokenizer)
# print(generated_text)

The kv_cache variable is what makes this fast. Without it, every new token would require the model to re-process the entire sequence from scratch. Suppose your prompt is 100 tokens long and you want to generate 100 new tokens. You’d make 100 forward passes, and each one would re-process the 100 prompt tokens plus every token generated so far. That comes to roughly 15,000 token positions in total instead of about 200. That’s incredibly wasteful.
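A quick count makes the gap concrete. The sketch below uses a hypothetical helper, `positions_processed`, to count how many token positions go through the model during greedy generation. It only counts positions. It ignores the fact that each new token’s attention over the cache still grows with sequence length.

```python
def positions_processed(prompt_len, new_tokens, use_cache):
    """Count token positions fed through the model during greedy generation (toy model)."""
    total = 0
    seq_len = prompt_len
    for step in range(new_tokens):
        if use_cache and step > 0:
            total += 1          # only the newest token is fed; earlier K/V come from the cache
        else:
            total += seq_len    # the whole sequence is (re)processed
        seq_len += 1            # the sampled token is appended
    return total

print(positions_processed(100, 100, use_cache=False))  # 14950
print(positions_processed(100, 100, use_cache=True))   # 199
```

With caching, the first pass handles the full prompt. Each later pass handles a single token.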

The KV cache stores the key and value states computed at every layer for each token processed so far. To generate the next token, the model runs only that new token through the network. It computes the token’s key and value states and appends them to the cached states from all previous tokens. This dramatically reduces computation, and the savings grow with sequence length.

Think of it like writing a long story. If you had to re-read the entire story from the beginning every time you wanted to write just one more sentence, it would take forever. The KV cache is like keeping your place in the manuscript so you only have to focus on the sentence you’re currently writing, referencing what came before without rereading it.

The core components involved are the Key (K) and Value (V) projections within the self-attention mechanism of a Transformer model. In every layer, each token’s hidden state is projected into Query (Q), Key (K), and Value (V) vectors for each attention head. The KV cache stores the K and V vectors for every token processed so far.

When a new token arrives, its Q, K, and V vectors are computed. Its K and V are appended to the stored ones, and its query then attends over the combined keys and values. Attention is causal, so earlier tokens’ keys and values never change as new tokens are appended, which means reusing them is exact. Queries from earlier tokens are never needed again, which is why only keys and values are cached.
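Here is a toy, single-head sketch of that mechanism in plain Python. It uses tiny random weights, a single head, and no positional encoding or output projection. The names `CachedHead`, `attend`, and `full_causal_attention` are hypothetical and do not come from any library. The sketch checks that attending with cached keys and values gives the same outputs as recomputing Q, K, and V for every prefix.

```python
import math
import random

def matvec(W, x):
    # W is a list of rows; returns the matrix-vector product W @ x
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def attend(q, keys, values):
    # Scaled dot-product attention of one query over the given keys/values
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]  # numerically stable softmax
    z = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / z for j in range(len(values[0]))]

class CachedHead:
    """One attention head that keeps past keys and values (a toy KV cache)."""

    def __init__(self, Wq, Wk, Wv):
        self.Wq, self.Wk, self.Wv = Wq, Wk, Wv
        self.k_cache, self.v_cache = [], []

    def step(self, x):
        # Project only the new token; earlier tokens' K and V are reused from the cache
        q = matvec(self.Wq, x)
        self.k_cache.append(matvec(self.Wk, x))
        self.v_cache.append(matvec(self.Wv, x))
        return attend(q, self.k_cache, self.v_cache)

def full_causal_attention(xs, Wq, Wk, Wv):
    # Reference without a cache: recompute K and V for the whole prefix at every position
    outs = []
    for t in range(len(xs)):
        q = matvec(Wq, xs[t])
        keys = [matvec(Wk, x) for x in xs[: t + 1]]
        vals = [matvec(Wv, x) for x in xs[: t + 1]]
        outs.append(attend(q, keys, vals))
    return outs

def rand_matrix(rows, cols):
    return [[random.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]

random.seed(0)
d_model, head_dim, seq_len = 8, 4, 5
Wq, Wk, Wv = (rand_matrix(head_dim, d_model) for _ in range(3))
xs = rand_matrix(seq_len, d_model)  # hypothetical hidden states, one row per token

head = CachedHead(Wq, Wk, Wv)
cached = [head.step(x) for x in xs]
reference = full_causal_attention(xs, Wq, Wk, Wv)
max_diff = max(abs(a - b) for ca, ra in zip(cached, reference) for a, b in zip(ca, ra))
print(len(head.k_cache), max_diff < 1e-12)  # 5 True
```

The cached version projects each token once. The reference re-projects the whole prefix at every step, yet both produce the same attention outputs.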

The exact size of the KV cache depends on the model’s architecture (number of layers, attention heads, and the dimension of the key/value vectors) and on the length of the sequence. Take a model with n_layers layers, n_heads heads per layer, and head dimension head_dim, and a sequence of length seq_len. The KV cache for a single batch element is 2 * n_layers * n_heads * head_dim * seq_len * sizeof(dtype) bytes, where the factor of 2 accounts for storing both keys and values. This grows linearly with sequence length and batch size, which is why managing it efficiently is crucial.
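Plugging numbers into that formula shows how quickly the cache grows. The configuration below is hypothetical: 32 layers, 32 heads of dimension 128, fp16 storage (2 bytes per element), and a 4096-token sequence. The helper name `kv_cache_bytes` is also illustrative.

```python
def kv_cache_bytes(n_layers, n_heads, head_dim, seq_len, bytes_per_element, batch_size=1):
    # Factor of 2: one key tensor and one value tensor per layer
    return 2 * n_layers * n_heads * head_dim * seq_len * bytes_per_element * batch_size

# Hypothetical configuration: 32 layers, 32 heads of dim 128, fp16 (2 bytes per element)
per_token = kv_cache_bytes(32, 32, 128, seq_len=1, bytes_per_element=2)
full = kv_cache_bytes(32, 32, 128, seq_len=4096, bytes_per_element=2)
print(per_token // 1024, "KiB per token")    # 512 KiB per token
print(full / 2**30, "GiB for 4096 tokens")   # 2.0 GiB for 4096 tokens
```

At half a mebibyte per token, a single long sequence in this configuration already needs gigabytes of cache. A batch of them multiplies that again.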

The one thing most people don’t realize is that the KV cache is per sequence. If you’re processing multiple independent prompts in parallel (a common scenario in inference servers), each prompt needs its own dedicated KV cache. The inference framework usually manages this (for example Hugging Face’s transformers, vLLM, or TensorRT-LLM), allocating and reusing memory for these caches. Some frameworks use techniques like "PagedAttention", introduced by vLLM, which stores each sequence’s cache in fixed-size blocks rather than one contiguous buffer. This handles variable sequence lengths and avoids memory fragmentation.
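Here is a toy sketch of the bookkeeping behind block-based (paged) KV caches. Each sequence gets its own block table, and blocks come from a shared pool and return to it when the sequence finishes. `PagedKVAllocator` and its methods are hypothetical names, and this is not vLLM’s implementation or API. It tracks only which blocks a sequence owns and stores no tensors.

```python
class PagedKVAllocator:
    """Toy block allocator in the spirit of PagedAttention (hypothetical, illustration only).

    The cache pool is split into fixed-size blocks; each sequence owns a block table
    listing the physical blocks that would hold its tokens' keys and values.
    """

    def __init__(self, num_blocks, block_size):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))
        self.block_tables = {}   # seq_id -> list of physical block ids
        self.lengths = {}        # seq_id -> number of tokens stored

    def append_token(self, seq_id):
        # Each sequence has its own cache; take a new block only when its last block is full
        table = self.block_tables.setdefault(seq_id, [])
        length = self.lengths.get(seq_id, 0)
        if length % self.block_size == 0:
            if not self.free_blocks:
                raise MemoryError("KV cache pool exhausted")
            table.append(self.free_blocks.pop())
        self.lengths[seq_id] = length + 1

    def free(self, seq_id):
        # Return a finished sequence's blocks to the pool so other sequences can reuse them
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)

alloc = PagedKVAllocator(num_blocks=8, block_size=16)
for _ in range(20):
    alloc.append_token("a")   # 20 tokens -> 2 blocks
for _ in range(5):
    alloc.append_token("b")   # 5 tokens -> 1 block
print(len(alloc.block_tables["a"]), len(alloc.block_tables["b"]), len(alloc.free_blocks))  # 2 1 5
alloc.free("a")
print(len(alloc.free_blocks))  # 7
```

Memory is reserved one block at a time rather than for a worst-case maximum length. As a result, at most one partially filled block per sequence goes unused.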

The next logical step after understanding KV caching is exploring how to further optimize inference, such as through quantization or speculative decoding.
