Speculative decoding lets an LLM generate text up to 2-3x faster by having a smaller, faster "draft" model predict ahead, and then a larger, more accurate "verification" model check those predictions in parallel.
Let’s see it in action. Imagine we have a slow but accurate LLM, LLM_Accurate, and a fast but less accurate "draft" LLM, LLM_Draft.
First, we need to set up the drafting and verification models.
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the accurate (verification) model and its tokenizer
model_accurate = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# For demonstration, we'll simulate a draft model. In a real scenario,
# this would be a smaller, faster model that shares the same tokenizer.
# Here we load the same checkpoint again: this shows the mechanics only and
# gives no speedup, because this "draft" costs as much as the accurate model.
model_draft = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# Speculative decoding means orchestrating draft and verify calls. This is a
# conceptual, manual version; libraries can also handle it with a dedicated
# speculative decoding utility.
prompt = "The quick brown fox jumps over the"
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs.input_ids
prompt_len = input_ids.shape[-1]

# --- Step 1: Draft Model Generates Candidates ---
# The draft model generates k tokens ahead. Let's say k=4.
k = 4
draft_output = model_draft.generate(
    input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=k,
    do_sample=True,  # Sample the drafts; greedy drafting (do_sample=False) also works
    temperature=0.7,
    top_p=0.9,
    pad_token_id=tokenizer.eos_token_id,  # Llama-2 defines no pad token; reuse EOS
)
draft_tokens = draft_output[0, prompt_len:]  # Keep only the new tokens (at most k)
print(f"Draft generated tokens: {tokenizer.decode(draft_tokens)}")
# The output varies from run to run because the draft samples.

# --- Step 2: Accurate Model Verifies Candidates ---
# Feed the *entire* sequence (prompt + draft tokens) to the accurate model.
# One forward pass returns logits at every position, so it scores all draft
# tokens in parallel and also predicts the token after the last draft.
verification_input_ids = torch.cat([input_ids, draft_tokens.unsqueeze(0)], dim=-1)
with torch.no_grad():  # Inference mode: no gradients needed
    outputs = model_accurate(verification_input_ids)
logits = outputs.logits  # shape: (1, prompt_len + num_drafts, vocab_size)

# Logits at position t predict token t + 1, so the predictions for the
# draft positions start at index prompt_len - 1.
draft_position_logits = logits[0, prompt_len - 1 : -1, :]  # one row per draft token
# Logits for the token *after* the draft sequence (the "bonus" position)
last_token_logits = logits[0, -1, :]

# Probability the accurate model assigns to each draft token at its position
target_probs = torch.softmax(draft_position_logits, dim=-1)
draft_token_probs = target_probs.gather(-1, draft_tokens.unsqueeze(-1)).squeeze(-1)
print(f"Accurate-model probabilities of the drafts: {draft_token_probs.tolist()}")

# The acceptance step compares these probabilities with the draft model's own
# probabilities for the same tokens to decide how many drafts to keep.
# Assume a hypothetical function that does this; in reality it is the core
# of the algorithm:
# accepted_tokens, new_input_ids = speculative_decoding_algorithm(
#     input_ids, draft_tokens, model_accurate, tokenizer, k
# )

# For demonstration, let's just accept all draft tokens for now. In a real
# scenario some may be rejected, and a replacement token is drawn from the
# accurate model's distribution at the first rejected position.
accepted_tokens = draft_tokens
final_generated_ids = torch.cat([input_ids, accepted_tokens.unsqueeze(0)], dim=-1)
print(f"Final accepted tokens: {tokenizer.decode(accepted_tokens)}")

# The next step would be to feed `final_generated_ids` back in as the prompt
# for the next iteration, and repeat the draft-then-verify process.
```
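The "accept everything" placeholder in the code above skips the actual decision. Assuming both models decode greedily (a simplification; the sampled case needs a probabilistic rule), the decision reduces to a prefix match: keep drafts while they equal the accurate model's top token, and at the first mismatch take the accurate model's token instead. The sketch below uses plain integer token IDs; `greedy_accept` is a hypothetical helper name, and `target_argmax` is assumed to hold the accurate model's top token at each draft position plus one more for the position after the last draft (with the tensors above, something like `logits[0, input_ids.shape[-1] - 1:].argmax(-1).tolist()`).

```python
def greedy_accept(draft_tokens, target_argmax):
    """Greedy speculative verification (a sketch).

    draft_tokens:  the k token IDs proposed by the draft model.
    target_argmax: k + 1 token IDs; entry i is the accurate model's most
                   likely token at draft position i, and the last entry is
                   its prediction for the position after the final draft.
    Returns the tokens to append: the accepted prefix plus one token chosen
    by the accurate model (a correction or a bonus token).
    """
    if len(target_argmax) != len(draft_tokens) + 1:
        raise ValueError("target_argmax must have one more entry than draft_tokens")
    accepted = []
    for drafted, predicted in zip(draft_tokens, target_argmax):
        if drafted != predicted:
            # First disagreement: drop the rest of the draft and use the
            # accurate model's own choice at this position instead.
            return accepted + [predicted]
        accepted.append(drafted)
    # Every draft matched, so the final prediction is a free bonus token.
    return accepted + [target_argmax[-1]]


# Usage with made-up token IDs:
print(greedy_accept([11, 22, 33, 44], [11, 22, 99, 55, 66]))  # [11, 22, 99]
print(greedy_accept([11, 22, 33, 44], [11, 22, 33, 44, 66]))  # [11, 22, 33, 44, 66]
```

With greedy decoding on both models, this prefix match reproduces exactly what the accurate model would have generated on its own, while always emitting at least one token per verification pass.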
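The core idea is that the draft model does a lot of cheap work (predicting k tokens), and the accurate model does one piece of expensive work (verifying those k tokens). If the draft model is reasonably good, its predictions will often align with what the accurate model would have produced anyway. Verification is much faster than generating k tokens with the accurate model because all k positions are scored in a single forward pass, whereas ordinary decoding needs k sequential passes. Since large-model decoding is usually limited by memory bandwidth rather than arithmetic, one pass over k + 1 positions costs roughly as much as a pass over one.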
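To see where the saving comes from, here is a toy simulation of greedy speculative decoding that counts accurate-model calls. It is a sketch, not real models: `target_next` and `draft_next` are hypothetical deterministic stand-ins over digit "tokens", and each verification step is counted as one call even though the toy evaluates every prefix separately. That counting stands in for a transformer scoring all k + 1 positions in one forward pass.

```python
def target_next(seq):
    # Hypothetical "accurate" model: a deterministic rule over the sequence.
    return (seq[-1] * 3 + len(seq)) % 10


def draft_next(seq):
    # Hypothetical "draft" model: agrees with the target except when the
    # context length is a multiple of 3.
    guess = target_next(seq)
    return guess if len(seq) % 3 else (guess + 1) % 10


def plain_decode(prompt, n_tokens):
    seq, target_calls = list(prompt), 0
    while len(seq) < len(prompt) + n_tokens:
        seq.append(target_next(seq))  # one accurate-model pass per token
        target_calls += 1
    return seq, target_calls


def speculative_decode(prompt, n_tokens, k=4):
    seq, target_calls = list(prompt), 0
    goal = len(prompt) + n_tokens
    while len(seq) < goal:
        # Drafting: k cheap sequential draft steps.
        draft = []
        for _ in range(k):
            draft.append(draft_next(seq + draft))
        # Verification: counted as ONE accurate-model pass over k + 1 positions.
        target_calls += 1
        predictions = [target_next(seq + draft[:i]) for i in range(k + 1)]
        new_tokens = []
        for drafted, predicted in zip(draft, predictions):
            if drafted != predicted:
                new_tokens.append(predicted)  # correction token
                break
            new_tokens.append(drafted)
        else:
            new_tokens.append(predictions[-1])  # bonus token
        seq.extend(new_tokens)
    return seq[:goal], target_calls  # trim any overshoot past the goal


plain, plain_calls = plain_decode([1], 20)
spec, spec_calls = speculative_decode([1], 20)
print(plain == spec, plain_calls, spec_calls)  # True 20 7
```

Because every emitted token is one the accurate model would have chosen greedily, both outputs are identical; only the number of accurate-model calls drops, here from 20 to 7 even though the draft is wrong about once every three tokens.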
Think of it like a committee. The "draft committee" (small, fast members) quickly suggests several ideas. Then, the "verification committee" (large, expert members) reviews these ideas. If the expert committee agrees with an idea, they adopt it. If not, they might discard it and come up with their own. Speculative decoding optimizes this review process.
Here’s how it works internally:
- Drafting Phase: The small, fast draft model takes the current sequence of tokens x and generates k candidate tokens one at a time. Each step is cheap, and the draft can use greedy decoding or sampling. Let’s say the draft model proposes y_1, y_2, ..., y_k.
- Verification Phase: The large, accurate verification model receives the entire sequence: the original input plus the k drafted tokens (x, y_1, y_2, ..., y_k). Because a transformer produces a next-token distribution at every position in a single forward pass, this one call yields both the probability of each drafted token y_i given x, y_1, ..., y_{i-1} and the distribution over the token after y_k.
- Acceptance/Rejection: This is the clever part. The algorithm walks through the drafted tokens in order and, at each position, compares the draft model’s probability q(y_i) with the verification model’s probability p(y_i). Under greedy decoding, y_i is accepted if it is the verification model’s top token. Under sampling, y_i is accepted with probability min(1, p(y_i) / q(y_i)). At the first rejected y_i, the remaining drafts y_i, ..., y_k are discarded, and a replacement token is drawn from the verification model’s already-computed distribution at that position (its top token under greedy decoding, or a sample from the normalized residual max(0, p - q) under sampling), so no extra forward pass is needed. If all k drafts are accepted, the distribution after y_k supplies one bonus token. This rule makes the output follow the same distribution as decoding with the accurate model alone.
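The sampled acceptance rule is easy to get subtly wrong, so here is a minimal sketch of it over toy distributions, following the speculative sampling rule described by Leviathan et al. and Chen et al. (2023). The assumptions: distributions are plain Python lists over a three-token vocabulary (in practice they would be softmax outputs of the two models), and `verify_sampled` and `residual_distribution` are hypothetical helper names. The usage check confirms the key property empirically: the first emitted token follows the accurate model's distribution p, not the draft's q.

```python
import random


def residual_distribution(p, q):
    """Normalized max(0, p - q): where to resample after a rejection."""
    diff = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(diff)
    if total == 0.0:
        # Only possible when p == q, in which case nothing is ever rejected.
        return list(p)
    return [d / total for d in diff]


def verify_sampled(draft_tokens, q_dists, p_dists, rng):
    """Speculative-sampling verification over toy distributions (a sketch).

    draft_tokens: k token IDs sampled from the draft model.
    q_dists[i]:   draft model's distribution at draft position i.
    p_dists[i]:   accurate model's distribution at position i; p_dists has
                  k + 1 entries, the last for the position after y_k.
    Returns the accepted drafts plus one replacement or bonus token.
    """
    out = []
    for i, y in enumerate(draft_tokens):
        p, q = p_dists[i], q_dists[i]
        # Accept y with probability min(1, p(y) / q(y)); q[y] > 0 since y ~ q.
        if rng.random() < min(1.0, p[y] / q[y]):
            out.append(y)
            continue
        # Rejected: draw a replacement from the residual and stop here.
        resid = residual_distribution(p, q)
        out.append(rng.choices(range(len(resid)), weights=resid)[0])
        return out
    # Every draft accepted: take a bonus token from the last distribution.
    last = p_dists[-1]
    out.append(rng.choices(range(len(last)), weights=last)[0])
    return out


# Usage: with k = 1, the first emitted token should follow p, not q.
rng = random.Random(0)
p = [0.5, 0.3, 0.2]  # accurate model
q = [0.2, 0.2, 0.6]  # draft model
counts = [0, 0, 0]
for _ in range(50_000):
    y = rng.choices(range(3), weights=q)[0]
    counts[verify_sampled([y], [q], [p, p], rng)[0]] += 1
print([c / 50_000 for c in counts])  # close to p = [0.5, 0.3, 0.2]
```

The residual step is what makes the output exact: tokens the draft over-proposes (here token 2) are thinned out by rejection, and the probability mass freed up is redistributed to exactly the tokens the draft under-proposes.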
The number of accepted draft tokens (m) can range from 0 (if even y_1 is rejected) to k (if all drafted tokens are accepted). Because every iteration also yields one token from the verification model (the replacement or the bonus token), each call to the accurate model produces m + 1 tokens instead of one. Even with m = 0 you make no more accurate-model calls than plain decoding (only the drafting work is wasted); the speedup comes from m being greater than 0 on average, and the better the draft model, the closer m gets to k.
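Under a simplifying assumption that the text above does not make, namely that each drafted token is accepted independently with the same probability α, the expected number of tokens per accurate-model call has a closed form. Here m is the number of accepted drafts, so it is at least j exactly when the first j drafts are all accepted:

```latex
\begin{aligned}
\Pr(m \ge j) &= \alpha^{j}, \qquad 0 \le j \le k,\\
\mathbb{E}[m + 1] &= 1 + \sum_{j=1}^{k} \Pr(m \ge j) = \sum_{j=0}^{k} \alpha^{j} = \frac{1 - \alpha^{k+1}}{1 - \alpha} \qquad (\alpha < 1).
\end{aligned}
```

For example, with k = 4 and α = 0.8 this gives (1 - 0.8^5) / 0.2 = 3.3616 tokens per accurate-model call, versus exactly 1 for plain decoding. The sum also shows diminishing returns from a larger k: each extra draft position contributes only α^j.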
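The "draft model" doesn’t have to be a completely different, smaller architecture. It can be the same architecture made cheaper to run, for example a copy of the accurate model with lower-precision weights (e.g., int8 or 4-bit when the accurate model runs in float16). Changing only the sampling strategy does not help, because the cost lies in the forward pass. The key is that the draft’s forward pass is much cheaper than the accurate model’s. In the standard algorithm, the draft must also share the accurate model’s tokenizer, since drafted tokens are verified by ID.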
In practice the loop is rarely written by hand. Hugging Face transformers exposes it as assisted generation (the draft model is passed to generate() through the assistant_model argument), and dedicated inference engines such as vLLM also support speculative decoding. The core logic revolves around efficiently calculating the probabilities from the verification model and using them to decide how many draft tokens to accept.
A crucial detail often overlooked is how the draft model is trained or selected. A smaller model from the same family is the usual choice, and a draft that is distilled or fine-tuned to imitate the accurate model’s outputs can raise the acceptance rate further; a draft as large as the accurate model saves nothing, since its forward pass costs just as much. How often the draft agrees with the accurate model is paramount: if it’s too inaccurate, the verification step will reject most tokens, and you’ll gain little to no speedup, potentially even losing speed due to the overhead of running the draft.
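The trade-off between draft quality and draft cost can be made concrete with a back-of-the-envelope model. This sketch rests on assumptions that are simplifications, not claims from the text above: each draft token is accepted independently with probability alpha, one draft pass costs a fraction c of an accurate-model pass, and a verification pass over k + 1 positions costs about as much as one ordinary decoding step. Under those assumptions the expected walltime improvement is (1 - alpha^(k+1)) / ((1 - alpha)(k*c + 1)), the form given by Leviathan et al. (2023). `expected_speedup` is a hypothetical helper name.

```python
def expected_speedup(alpha, k, c):
    """Idealized walltime improvement of speculative over plain decoding.

    alpha: probability that each draft token is accepted (assumed i.i.d.).
    k:     draft tokens proposed per iteration.
    c:     cost of one draft forward pass relative to one accurate-model pass.
    Assumes one verification pass costs the same as one ordinary decoding step.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    # Expected tokens per iteration: sum_{j=0}^{k} alpha^j (accepted drafts + 1).
    tokens_per_iteration = sum(alpha ** j for j in range(k + 1))
    # Cost per iteration in units of one accurate-model pass: k drafts + 1 verify.
    cost_per_iteration = k * c + 1
    return tokens_per_iteration / cost_per_iteration


# A draft that agrees often vs. one that rarely does (k = 4, draft cost 5%).
for alpha in (0.8, 0.1):
    print(alpha, round(expected_speedup(alpha, k=4, c=0.05), 3))
```

With these numbers the sketch gives about 2.8x for the good draft and about 0.93x, an actual slowdown, for the poor one: the draft's own cost is paid on every iteration whether or not its tokens survive.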
The next challenge after successfully implementing speculative decoding is understanding its sensitivity to quantization and model parallelism.