Quantization, the KV cache, and batching are often treated as separate knobs for tuning GPU inference, but they are fundamentally intertwined: optimizing one without considering the others can leave significant performance on the table.

Let’s see this in action. Imagine we’re running inference for a text generation model. Without any deliberate optimization (full FP32 weights, one request at a time), a single request might look like this:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

prompt = "The quick brown fox jumps over the lazy"
inputs = tokenizer(prompt, return_tensors="pt").to(device)

# Single, unbatched inference
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

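This works, but it is inefficient: the weights are stored in full FP32 precision, and the GPU processes a single short sequence at a time, leaving most of its compute idle. Now, let’s start layering in optimizations.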

Quantization: Making Weights Smaller

Quantization reduces the numerical precision of the model’s weights (and, in some schemes, its activations). Instead of 32-bit floating-point numbers (FP32), we can use 16-bit floats (FP16), 8-bit integers (INT8), or even 4-bit integers (INT4). The primary benefit is a smaller model and lower memory-bandwidth requirements, which often translates into faster inference, although the actual speedup depends on the hardware and software having efficient kernels for the chosen format.
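A quick back-of-the-envelope calculation shows what these precisions mean for the weights alone. The sketch below assumes a hypothetical 7-billion-parameter model and ignores the small overhead of quantization scales, as well as activations and the KV cache; the helper name weight_memory_gb is illustrative.

```python
# Bytes needed to store one parameter at each precision.
BYTES_PER_PARAM = {"fp32": 4.0, "fp16": 2.0, "int8": 1.0, "int4": 0.5}


def weight_memory_gb(num_params: float, precision: str) -> float:
    """Approximate memory for the model weights alone, in GB (1e9 bytes).

    Ignores quantization scales/zero-points, activations, the KV cache,
    and framework overhead.
    """
    return num_params * BYTES_PER_PARAM[precision] / 1e9


# Hypothetical 7B-parameter model.
for precision in BYTES_PER_PARAM:
    print(f"7B params @ {precision}: {weight_memory_gb(7e9, precision):.1f} GB")
```

Each halving of the bit width halves the bytes that must be read from GPU memory for every forward pass, which is where most of the decode-time speedup comes from.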

  • Why it matters: GPUs are incredibly fast at matrix multiplications, but they are often bottlenecked by how quickly they can fetch data from memory. During token-by-token generation, every weight has to be read from GPU memory at each step, so smaller data types mean fewer bytes to move per step and more of the data fits in the GPU’s caches.
  • Common Techniques:
    • Post-Training Quantization (PTQ): Quantize a pre-trained model without retraining. This is the simplest and fastest approach.
      • Example (using bitsandbytes for INT8):
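        pip install bitsandbytes accelerate
        
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
        
        model_name = "gpt2"
        # bitsandbytes quantization primarily targets CUDA GPUs. device_map="auto"
        # places the quantized weights on the GPU; .to(device) is not supported
        # for 8-bit models, so it is not called here.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            device_map="auto",
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        
        prompt = "The quick brown fox jumps over the lazy"
        # Send the inputs to whatever device the model was placed on.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=20)
        print(tokenizer.decode(outputs[0], skip_special_tokens=True))
        print(f"Memory footprint: {model.get_memory_footprint() / 1e6:.1f} MB")
        
        Compared with FP32, INT8 storage cuts the memory used by the quantized linear layers by about 4x (1 byte per weight instead of 4). The saving for the whole model is smaller, because components such as the embeddings, layer norms, and output head are typically kept in 16-bit precision.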
    • Quantization-Aware Training (QAT): Simulate quantization during training to minimize accuracy loss. This is more complex but yields better results for aggressive quantization (e.g., INT4).
  • Trade-offs: Aggressive quantization can lead to a drop in model accuracy. INT8 is generally safe for most models, while INT4 might require careful calibration or QAT.
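To see where the accuracy cost comes from, here is a minimal sketch of symmetric per-output-channel INT8 weight quantization in NumPy: each row gets one scale so that its largest absolute value maps to 127, values are rounded to integers, and dequantization multiplies back. This is a generic illustration, not the algorithm bitsandbytes uses (its LLM.int8() scheme additionally keeps outlier features in 16-bit); the function names and the random 768 × 768 matrix are hypothetical.

```python
import numpy as np


def quantize_int8(w: np.ndarray):
    """Symmetric per-row INT8 quantization: one FP32 scale per output channel."""
    scale = np.abs(w).max(axis=1, keepdims=True) / 127.0
    scale = np.where(scale == 0, 1.0, scale)  # avoid dividing by zero for all-zero rows
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale.astype(np.float32)


def dequantize_int8(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    """Map INT8 values back to approximate FP32 weights."""
    return q.astype(np.float32) * scale


rng = np.random.default_rng(0)
w = rng.standard_normal((768, 768)).astype(np.float32)  # toy weight matrix
q, scale = quantize_int8(w)
w_hat = dequantize_int8(q, scale)

# Rounding error is the price paid for the smaller storage.
rel_err = float(np.linalg.norm(w - w_hat) / np.linalg.norm(w))
fp32_bytes = w.nbytes
int8_bytes = q.nbytes + scale.nbytes  # INT8 values plus one FP32 scale per row
print(f"FP32: {fp32_bytes} bytes, INT8 + scales: {int8_bytes} bytes "
      f"({fp32_bytes / int8_bytes:.2f}x smaller), relative error {rel_err:.4f}")
```

The per-row scales are why the saving is slightly under 4x. Going to INT4 leaves only 15 usable levels per scale instead of 255, so the rounding error grows sharply, which is why INT4 usually needs finer-grained scales, careful calibration, or QAT.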

KV Cache: Remembering Past Computations

In autoregressive models (like GPTs), each new token attends to every token before it. Without caching, each generation step would recompute the keys and values for the entire prefix, even though they never change: with causal attention, earlier positions never depend on later ones. The KV cache stores the key and value projections from each self-attention layer for previous tokens, so they don’t need to be recomputed at every generation step.

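  • Why it matters: Without a cache, step t pushes all t tokens back through every layer, so the total work grows much faster than the number of generated tokens. With the cache, each step runs only the new token through the projections and MLP layers, a constant amount of work per token. Attention itself is not constant-time: the new token’s query still reads every cached key and value, so that part of the per-token cost grows linearly with the context length.
  • How it works: During prefill (the pass over the prompt), the K and V vectors for all prompt tokens are computed in every layer and stored. At each subsequent step, only the new token’s K and V vectors are computed and appended to the cache, and the new token’s query attends over the full cached K and V.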
  • Implementation: Hugging Face transformers enables the KV cache by default in generate() (controlled by the use_cache argument), so you usually get it without any extra work. Passing the argument explicitly just makes the setting visible:
    # use_cache=True is already the default for GPT-2; passing it here only
    # documents the setting. Reuses model, tokenizer and inputs from the
    # first example.
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=20, use_cache=True)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
    
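  • Trade-offs: The KV cache consumes GPU memory. Per token, it holds one key and one value vector for every layer and key/value head (2 × layers × KV heads × head dimension × bytes per element), so its total size scales linearly with both batch size and sequence length. For large batches and long sequences, the KV cache can outgrow the model weights and become the main memory bottleneck.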
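The mechanism can be checked with a toy example. The sketch below uses NumPy and a single attention head with random projection weights, a deliberate simplification (no multiple heads or layers, no positional encodings, no batching). It decodes ten "tokens" step by step, once with a cache and once recomputing keys and values for the whole prefix, and checks that the outputs match. KVCache, attend and step_without_cache are hypothetical names, not a library API.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 16  # head dimension (toy value)
W_q, W_k, W_v = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))


def attend(q, K, V):
    """Scaled dot-product attention of one query over keys K and values V."""
    scores = K @ q / np.sqrt(d)               # (t,)
    weights = np.exp(scores - scores.max())   # numerically stable softmax
    weights /= weights.sum()
    return weights @ V                        # (d,)


def step_without_cache(xs):
    """Recompute K and V for the whole prefix, then attend from the last token."""
    X = np.stack(xs)                          # (t, d)
    K, V = X @ W_k, X @ W_v                   # t projections at every step
    return attend(xs[-1] @ W_q, K, V)


class KVCache:
    """Keeps past keys/values so each step projects only the newest token."""

    def __init__(self):
        self.K = np.empty((0, d))
        self.V = np.empty((0, d))

    def step(self, x):
        self.K = np.vstack([self.K, x @ W_k])   # append one new key
        self.V = np.vstack([self.V, x @ W_v])   # append one new value
        return attend(x @ W_q, self.K, self.V)  # query still reads the full cache


tokens = [rng.standard_normal(d) for _ in range(10)]
cache = KVCache()
max_diff = 0.0
for t in range(1, len(tokens) + 1):
    cached = cache.step(tokens[t - 1])
    full = step_without_cache(tokens[:t])
    max_diff = max(max_diff, float(np.abs(cached - full).max()))
print(f"max difference between cached and recomputed outputs: {max_diff:.2e}")
```

The cached path computes one key and one value per step, while attend still touches all t cached rows, which is the linear per-token attention cost described above. np.vstack copies the whole cache each step for simplicity; production implementations preallocate the cache or manage it in fixed-size blocks (as vLLM’s PagedAttention does).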

Batching: Processing Multiple Requests Together

Batching involves grouping multiple inference requests together and processing them simultaneously through the model. This allows the GPU to perform computations on multiple inputs at once, significantly increasing throughput.

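  • Why it matters: GPUs are highly parallel processors. During decoding, a single request turns each layer into a matrix-vector product: the GPU reads every weight from memory to do very little arithmetic, leaving most cores idle. Batching turns these into matrix-matrix products, so each weight read is reused across every sequence in the batch, amortizing the memory traffic and keeping the cores busy.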
  • Types of Batching:
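    • Static Batching: A fixed group of requests is processed together from start to finish. Shorter prompts are padded to the length of the longest one, and sequences that finish early keep occupying their slot until the whole batch is done. This is simple but wasteful when request lengths vary.
    • Continuous Batching (also called in-flight or iteration-level batching): The scheduler revisits the batch at every generation step. New requests join as soon as there is room, and completed requests leave immediately, so the GPU is never held up by the longest sequence in a group. This is much more efficient for variable-length requests. (Request-level dynamic batching, by contrast, groups requests that arrive within a short time window and then runs that group as a unit.)
  • Example (Conceptual Continuous Batching): Imagine a server handling requests.
    1. Request A arrives (prompt length 10).
    2. Request B arrives (prompt length 15).
    3. The system forms a batch [A, B] and sends it to the GPU for prompt processing (prefill), which produces the first output token for each.
    4. Request C arrives (prompt length 8).
    5. A and B are now ready to generate their second token. C has not been processed yet.
    6. The system forms the next batch [A (token 2), B (token 2), C (prefill, token 1)] and sends it to the GPU. C joins without waiting for A and B to finish.
    7. This continues: whenever a request emits its end-of-sequence token or reaches its length limit, it leaves the batch, and its slot and KV cache memory go to a waiting request.
  • Implementation: vLLM implements continuous batching. NVIDIA’s Triton Inference Server provides a request-level dynamic batcher, and its TensorRT-LLM backend supports in-flight batching.
  • Trade-offs: Batching increases memory usage, since every sequence in the batch needs its own KV cache and intermediate activations. It can also increase latency: requests may wait for a batch to form (or, with static batching, for other requests in the batch to complete), and each step takes longer as the batch grows.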
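The scheduling loop behind the conceptual example can be sketched in pure Python. This is a toy model under stated assumptions: every running request emits exactly one token per step, a request’s first step is its prefill, and the only capacity limit is a hypothetical max_batch_size (real schedulers, such as vLLM’s, also budget KV cache memory). Request, continuous_batching, and the request lengths are illustrative, not any library’s API.

```python
from collections import deque
from dataclasses import dataclass


@dataclass
class Request:
    name: str
    arrival_step: int        # scheduler iteration at which the request shows up
    tokens_to_generate: int  # output length, known here only for simulation
    generated: int = 0


def continuous_batching(requests, max_batch_size):
    """Iteration-level scheduling: at every step, finished requests leave and
    waiting requests join, up to max_batch_size. Returns (step, batch) pairs."""
    pending = deque(sorted(requests, key=lambda r: r.arrival_step))
    running, timeline, step = [], [], 0
    while pending or running:
        # Admit requests that have arrived, while there is room in the batch.
        while (pending and pending[0].arrival_step <= step
               and len(running) < max_batch_size):
            running.append(pending.popleft())
        if running:
            timeline.append((step, [r.name for r in running]))
            for r in running:
                r.generated += 1  # one forward pass = one new token per request
            # Evict finished requests immediately, freeing their slot (and KV cache).
            running = [r for r in running if r.generated < r.tokens_to_generate]
        step += 1
    return timeline


reqs = [Request("A", 0, 3), Request("B", 0, 5), Request("C", 1, 2)]
for step, batch in continuous_batching(reqs, max_batch_size=4):
    print(step, batch)
```

Here C joins at step 1 while A and B are still running, and B runs alone once A and C have finished. With static batching, C would instead wait until B, the longest request in the first batch, was done.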

The Synergy: Putting It All Together

The real magic happens when these techniques are used in concert.

  1. Quantization reduces the memory footprint of the model itself. This frees up GPU VRAM.
  2. KV Cache dramatically speeds up token generation by avoiding recomputation. However, it consumes VRAM that scales with batch size and sequence length.
  3. Batching increases throughput by utilizing the GPU more fully. Continuous batching is key for efficiency with variable request lengths.

With quantization, you can afford a larger batch because the model weights take up less space. A larger batch, in turn, spreads the cost of reading the weights from memory (and of launching GPU kernels) over more sequences per step, leading to higher throughput. The KV cache keeps each decoding step cheap for every sequence in the batch.
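To see why a larger batch raises throughput, it helps to count how many floating-point operations a linear layer performs per byte it moves. The sketch below is a back-of-the-envelope model, not a profiler measurement, for a hypothetical 4096 × 4096 FP16 projection; it counts reading the weights and inputs and writing the outputs once, and ignores caches and kernel details. The function name is illustrative.

```python
def matmul_arithmetic_intensity(batch: int, d_in: int, d_out: int,
                                bytes_per_elem: int = 2) -> float:
    """FLOPs per byte moved for y = x @ W, with x of shape (batch, d_in) and
    W of shape (d_in, d_out). Counts W and x read once and y written once."""
    flops = 2 * batch * d_in * d_out  # one multiply and one add per weight per sequence
    bytes_moved = bytes_per_elem * (d_in * d_out + batch * d_in + batch * d_out)
    return flops / bytes_moved


# Hypothetical 4096 x 4096 FP16 projection.
for batch in (1, 8, 64, 256):
    print(f"batch={batch:4d}: {matmul_arithmetic_intensity(batch, 4096, 4096):6.1f} FLOPs/byte")
```

At batch size 1 the layer does about one FLOP per byte, so its speed is set by memory bandwidth; the intensity grows almost linearly with the batch because the same weights serve every sequence. Until it approaches the GPU’s ratio of compute throughput to memory bandwidth, extra sequences cost little additional time per step.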

Consider a scenario: you have 8GB of VRAM and a model that takes 4GB in FP16, which leaves about 4GB for the KV cache and activations and limits you to a small batch. Quantizing the weights to INT8 shrinks them to about 2GB and leaves about 6GB, 1.5 times as much room, so you can run a larger batch and get higher throughput, provided your batching strategy can keep that larger batch full. The KV cache, however, still grows with batch size and sequence length, so the extra headroom runs out eventually.
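The same budgeting can be done numerically. This sketch assumes a hypothetical model whose FP16 weights take 4 GB, with 24 layers, 16 key/value heads of dimension 128, an FP16 KV cache, and sequences of 2,048 tokens; it ignores activations, the CUDA context, and allocator overhead, so real limits will be somewhat lower. The helper names are illustrative.

```python
def kv_cache_bytes_per_token(n_layers, n_kv_heads, head_dim, bytes_per_elem):
    """One key vector and one value vector per layer and KV head."""
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem


def max_batch_size(vram_bytes, weight_bytes, seq_len, kv_bytes_per_token):
    """Largest batch whose full-length KV caches fit next to the weights.
    Ignores activations and framework overhead."""
    free = vram_bytes - weight_bytes
    return max(0, int(free // (seq_len * kv_bytes_per_token)))


GB = 1e9
# Hypothetical model: 24 layers, 16 KV heads of dim 128, FP16 cache (2 bytes/elem).
per_token = kv_cache_bytes_per_token(24, 16, 128, 2)
for label, weights in [("FP16 weights (4 GB)", 4 * GB), ("INT8 weights (2 GB)", 2 * GB)]:
    bs = max_batch_size(8 * GB, weights, seq_len=2048, kv_bytes_per_token=per_token)
    print(f"{label}: up to {bs} sequences of 2048 tokens")
```

Under these assumptions, quantizing the weights raises the ceiling from 9 to 14 full-length sequences, roughly in proportion to the extra free memory (6 GB instead of 4 GB). Quantizing the KV cache itself, or serving shorter sequences, would move the ceiling further.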

The next logical step after optimizing these aspects for throughput is to consider techniques that reduce the latency for individual requests, such as speculative decoding.
