Async processing is the secret sauce that lets LLM inference services handle way more requests than you’d expect for the GPU power they have.

Let’s say you’ve got a beefy A100 GPU, capable of 100 tokens/second for a single LLM request. You’re getting slammed with 100 requests per second, each asking for 100 tokens. If you handle them one by one, synchronously, your GPU can only process one request at a time. It finishes request 1 (1 second), then starts request 2, and so on. By the time request 100 is done, 100 seconds have passed and thousands more requests have piled up behind it. The GPU is technically busy the whole time, but decoding a single sequence uses only a small fraction of its parallel compute, so most of the hardware you’re paying for goes unused.
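To make the backlog concrete, here is a back-of-the-envelope sketch using the numbers above. The function name is made up for illustration, and it assumes perfectly steady arrivals and no overhead, which real traffic won’t give you.

```python
def sequential_backlog(arrival_rate, tokens_per_request, tokens_per_second, elapsed_seconds):
    """Requests still waiting after `elapsed_seconds` of one-at-a-time serving."""
    service_rate = tokens_per_second / tokens_per_request  # requests finished per second
    arrived = arrival_rate * elapsed_seconds
    finished = min(arrived, service_rate * elapsed_seconds)
    return arrived - finished

# 100 requests/s arriving, 100 tokens each, 100 tokens/s of single-request decode speed
print(sequential_backlog(100, 100, 100, elapsed_seconds=1))   # 99.0
print(sequential_backlog(100, 100, 100, elapsed_seconds=60))  # 5940.0
```

After one minute of one-at-a-time serving, almost six thousand requests are still waiting, and the gap keeps growing.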

Now, imagine you have a queue. When requests come in, you don’t immediately send them to the GPU. You stash them in a queue. A separate process, let’s call it the "batcher," watches this queue. When it sees a few requests (say, 8, or 16, or whatever fits nicely on your GPU), it bundles them together into a batch. This batch is then sent to the GPU, which processes all 8 requests in parallel. The GPU does more arithmetic for the batch, but token-by-token generation is typically limited by how fast the model weights can be read from memory, and a batch reads those weights once per step for all 8 requests. So a batch of 8 takes far less than 8x the time of a single request. While the GPU is chugging through that batch, new requests keep accumulating in the queue, ready to form the next batch.
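The queue-plus-batcher idea can be sketched in a single process with asyncio. This is a toy under stated assumptions, not a production design: `fake_gpu_batch`, `submit`, and `demo` are hypothetical names, the "GPU" is a short sleep, and the batch size of 8 is just the example figure from the paragraph above.

```python
import asyncio

MAX_BATCH_SIZE = 8  # assumed value; tune for your model and GPU

async def fake_gpu_batch(prompts):
    """Stand-in for one batched model call; a real inference call would go here."""
    await asyncio.sleep(0.01)
    return [f"output for {p}" for p in prompts]

async def batcher(queue):
    """Take whatever is waiting (up to MAX_BATCH_SIZE) and run it as one batch."""
    while True:
        batch = [await queue.get()]  # block until at least one request exists
        while len(batch) < MAX_BATCH_SIZE and not queue.empty():
            batch.append(queue.get_nowait())
        outputs = await fake_gpu_batch([prompt for prompt, _ in batch])
        for (_, future), output in zip(batch, outputs):
            future.set_result(output)  # wake up the caller waiting on this request

async def submit(queue, prompt):
    """Enqueue a prompt and wait for the batcher to deliver its result."""
    future = asyncio.get_running_loop().create_future()
    await queue.put((prompt, future))
    return await future

async def demo(n_requests=20):
    queue = asyncio.Queue()
    task = asyncio.create_task(batcher(queue))
    results = await asyncio.gather(*(submit(queue, f"prompt {i}") for i in range(n_requests)))
    task.cancel()
    return results

results = asyncio.run(demo())
print(len(results), results[0])
```

While the simulated GPU call is in flight, new requests pile up in the queue, so the next batch is already waiting when the current one finishes.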

Here’s how a typical async LLM inference setup looks in practice. We’ll use a common pattern involving a web server (like FastAPI), a message queue (like Redis with RQ), and the LLM inference engine itself.

1. The Web Server (FastAPI)

This is the entry point for your requests. It receives the prompt and any parameters. Instead of processing it directly, it puts the request onto a queue and immediately returns a job ID to the client.

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import redis
from rq import Queue
import uuid

app = FastAPI()

# Connect to Redis
redis_conn = redis.Redis(host='localhost', port=6379, db=0)
# Create an RQ queue
q = Queue(connection=redis_conn)

class LLMRequest(BaseModel):
    prompt: str
    max_tokens: int = 50

@app.post("/generate/")
async def generate_text(request: LLMRequest):
    job_id = str(uuid.uuid4())
    # Enqueue the task; RQ's keyword for a custom job ID is job_id
    q.enqueue(run_inference_worker, request.prompt, request.max_tokens, job_id, job_id=job_id)
    return {"job_id": job_id, "status": "queued"}

@app.get("/status/{job_id}")
async def get_job_status(job_id: str):
    job = q.fetch_job(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")
    return {"job_id": job_id, "status": job.get_status(), "result": job.result}

# Dummy function for the worker, this would be your actual inference call
def run_inference_worker(prompt: str, max_tokens: int, job_id: str):
    print(f"Processing job {job_id}: {prompt[:30]}...")
    # In a real scenario, this would call your LLM model
    # For demonstration, we'll simulate work and return a result
    import time
    time.sleep(2) # Simulate inference time
    result = f"Generated text for '{prompt[:20]}...' with max_tokens={max_tokens}"
    return result
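On the client side, the pattern is "submit, then poll." Here is a minimal polling sketch. The helper name `wait_for_result` and the injected `fetch_status` callable are hypothetical. It also assumes the /status endpoint reports RQ’s usual status strings such as "finished" and "failed".

```python
import time

def wait_for_result(fetch_status, job_id, poll_interval=0.5, timeout=30.0):
    """Poll fetch_status(job_id) until the job finishes, fails, or times out."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        status = fetch_status(job_id)
        if status["status"] == "finished":
            return status["result"]
        if status["status"] == "failed":
            raise RuntimeError(f"job {job_id} failed")
        time.sleep(poll_interval)
    raise TimeoutError(f"job {job_id} did not finish within {timeout}s")

# Stand-in for the HTTP call so the sketch runs without a server.
_responses = iter([
    {"status": "queued", "result": None},
    {"status": "started", "result": None},
    {"status": "finished", "result": "hello"},
])
print(wait_for_result(lambda job_id: next(_responses), "demo-job", poll_interval=0.01))  # hello
```

In a real client, `fetch_status` would wrap an HTTP GET against /status/{job_id} and return the parsed JSON body.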

2. The Message Queue (Redis with RQ)

Redis acts as the broker, holding tasks until a worker is ready. RQ (Redis Queue) is a simple Python library for managing these tasks.

3. The Worker Process

This is where the actual LLM inference happens. You’ll have one or more worker processes running, which poll the queue for new jobs. When a job is picked up, the worker executes the run_inference_worker function. Note that a stock RQ worker, like the one below, runs one job at a time. To get batching, the worker has to pull multiple jobs from the queue and hand them to the LLM together, which is what the batching logic later in this post sketches.

# worker.py
from rq import Worker
import redis
import time

# Connect to Redis
redis_conn = redis.Redis(host='localhost', port=6379, db=0)

# RQ looks up the job function by its import path, so the worker must be able
# to import it under the same module path the web app used when enqueueing.
def run_inference_worker(prompt: str, max_tokens: int, job_id: str):
    print(f"Processing job {job_id}: {prompt[:30]}...")
    # In a real scenario, this would call your LLM model
    # For demonstration, we'll simulate work and return a result
    time.sleep(2) # Simulate inference time
    result = f"Generated text for '{prompt[:20]}...' with max_tokens={max_tokens}"
    return result

if __name__ == "__main__":
    # Listen to the default queue
    # In a real-world scenario, you might have multiple queues for different models or priorities
    # The connection is passed explicitly; the Connection context manager
    # is deprecated and no longer available in recent RQ releases.
    worker = Worker(['default'], connection=redis_conn)
    print("Worker started. Listening for jobs...")
    worker.work()

To run the worker: rq worker in your terminal.

The Batching Logic (The Core Concept)

The run_inference_worker function shown above is a simplification. A real-world batching worker would look more like this:

# batched_worker.py
import time

import redis
from rq import Queue
# transformers and torch are only needed once the simulated generation
# below is replaced with a real model call.

# --- Configuration ---
MODEL_NAME = "gpt2" # Replace with your actual LLM model
MAX_BATCH_SIZE = 8
GPU_DEVICE = 0 # 0 for the first GPU, -1 for CPU

# --- Initialize Model ---
# Using Hugging Face's pipeline for simplicity.
# The key is how 'batch_size' is handled internally or how you manage it.
# For true async batching, you'd typically manage this at a lower level.
# For demonstration, we'll simulate batching by processing multiple jobs at once.

# Connect to Redis
redis_conn = redis.Redis(host='localhost', port=6379, db=0)
q = Queue(connection=redis_conn)

# A simple in-memory queue for jobs to be batched
pending_jobs = []
job_map = {} # Map job_id to original request details

def process_batch():
    global pending_jobs
    if not pending_jobs:
        return

    # Get jobs from the pending queue
    current_batch_jobs = pending_jobs[:MAX_BATCH_SIZE]
    pending_jobs = pending_jobs[MAX_BATCH_SIZE:]

    prompts = [job['prompt'] for job in current_batch_jobs]
    job_ids = [job['job_id'] for job in current_batch_jobs]

    print(f"Processing batch of {len(prompts)} jobs: {job_ids}")

    # --- Actual LLM Inference (Simulated) ---
    # In a real scenario, you'd pass these prompts to your batched LLM model.
    # For Hugging Face pipelines, batching is often handled automatically if you pass a list.
    # For more advanced control (e.g., padding, different sequence lengths),
    # you might use the model directly with torch.
    try:
        # Example using a Hugging Face pipeline (assumes batching support)
        # You might need to configure padding and truncation appropriately.
        # For true custom batching, you'd use model.generate() with tokenized inputs.
        # This example simulates the *outcome* of batch processing.
        generated_texts = [f"Generated for {p[:10]}..." for p in prompts] # Simulate generation
        time.sleep(3) # Simulate batch inference time

        # Update results for each job
        for i, job_id in enumerate(job_ids):
            job_data = job_map.get(job_id)
            if job_data:
                # In a real RQ setup, you'd persist the result on the job here.
                # For this simulation, we'll just print and assume it gets stored.
                print(f"Job {job_id} completed with result: {generated_texts[i]}")
                # Example: job_data['job'].set_status(JobStatus.FINISHED)  # from rq.job
                # How the return value is stored depends on your RQ version.
            else:
                print(f"Warning: Job data not found for {job_id}")
    except Exception as e:
        print(f"Error processing batch: {e}")
        # Mark failed jobs
        for job_id in job_ids:
            job_data = job_map.get(job_id)
            if job_data:
                print(f"Job {job_id} failed: {e}")
                # Example: job_data['job'].set_status(JobStatus.FAILED)  # from rq.job
                # How the error details are stored depends on your RQ version.
            else:
                print(f"Warning: Job data not found for failed job {job_id}")

    # Clean up processed jobs from map
    for job_id in job_ids:
        job_map.pop(job_id, None)

def worker_loop():
    print("Batched worker started. Listening for jobs...")
    while True:
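        # Fetch the next job from the RQ queue without blocking.
        # Queue.dequeue_any returns a (job, queue) tuple, or None when nothing is waiting.
        dequeued = Queue.dequeue_any([q], timeout=None, connection=redis_conn)
        job = dequeued[0] if dequeued else None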
        if job:
            print(f"Received job {job.id}")
            # Add job details to our pending queue
            pending_jobs.append({
                'job_id': job.id,
                'prompt': job.args[0], # Assuming prompt is the first arg
                'max_tokens': job.args[1] # Assuming max_tokens is the second arg
            })
            job_map[job.id] = {'job': job} # Store job object for later update

            # If the batch is full or we've waited long enough (optional timeout)
            if len(pending_jobs) >= MAX_BATCH_SIZE:
                process_batch()
        else:
            # If no jobs, process any pending jobs if they've accumulated
            # This is a simple "flush" mechanism. A real system might have timeouts.
            if pending_jobs:
                print("No new jobs, processing pending batch...")
                process_batch()
            time.sleep(1) # Wait a bit before checking the queue again

if __name__ == "__main__":
    # This is a simplified worker. A real one would use RQ's Worker class
    # and integrate batching more elegantly. This shows the core logic.
    worker_loop()

The key here is MAX_BATCH_SIZE. By setting this to a value that comfortably fits on your GPU (e.g., 8, 16, 32, 64 depending on model and hardware), you maximize GPU utilization. The GPU is no longer waiting for individual requests; it’s processing a steady stream of batches.
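The worker loop above flushes only when the batch is full or the queue runs dry. A common refinement is to also flush once the oldest pending request has waited too long, so a quiet period doesn’t leave a half-full batch hanging. Here is a minimal sketch of that rule. The `BatchCollector` class and the `max_wait_seconds` parameter are hypothetical and not part of the code above.

```python
import time

class BatchCollector:
    """Accumulate items and report when a batch should be flushed."""

    def __init__(self, max_batch_size=8, max_wait_seconds=0.05, clock=time.monotonic):
        self.max_batch_size = max_batch_size
        self.max_wait_seconds = max_wait_seconds
        self.clock = clock  # injectable so the timing rule can be tested
        self.items = []
        self.oldest_arrival = None

    def add(self, item):
        if not self.items:
            self.oldest_arrival = self.clock()  # start the wait timer
        self.items.append(item)

    def should_flush(self):
        if not self.items:
            return False
        if len(self.items) >= self.max_batch_size:
            return True  # batch is full
        # Otherwise flush only if the oldest item has waited long enough.
        return self.clock() - self.oldest_arrival >= self.max_wait_seconds

    def flush(self):
        batch = self.items[:self.max_batch_size]
        self.items = self.items[self.max_batch_size:]
        # Leftover items (if any) get a fresh timer.
        self.oldest_arrival = self.clock() if self.items else None
        return batch
```

In the worker loop, you would call `add` for each dequeued job and run the batch whenever `should_flush()` returns True. The wait limit caps how much latency batching can add when traffic is light.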

Why This Cuts Costs

GPUs are expensive. You pay for them whether they’re fully utilized or sitting idle. By batching, you dramatically increase the throughput of each GPU.

  • Synchronous: one request at a time at 100 tokens/sec, with 100 tokens per request, works out to 1 request/sec per GPU. The incoming load of 100 requests/sec * 100 tokens = 10,000 tokens/sec is 100x more than that, and real throughput is lower still once you add overhead.
  • Asynchronous Batched: If your batch size is 8 and you can process 8 requests in the same amount of time it took for one before (or even slightly longer, say 1.5s instead of 1s), your throughput jumps. If one request took 1 second, and a batch of 8 takes 1.5 seconds:
    • You’re processing 8 requests in 1.5 seconds.
    • That’s 8 / 1.5 = ~5.3 requests per second for the whole GPU.
    • Your GPU is now handling ~5.3 requests per second instead of 1.
    • This means you might need roughly 1/5th the number of GPUs for the same load, directly translating to cost savings.
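The arithmetic is easy to check: a GPU that finishes a batch of 8 every 1.5 seconds completes 8 / 1.5 requests per second in total. The helper names below are made up for illustration, and the 1.5-second batch time is the example figure from the text, not a measurement.

```python
import math

def requests_per_second(batch_size, batch_seconds):
    """Throughput of one GPU that completes `batch_size` requests every `batch_seconds`."""
    return batch_size / batch_seconds

def gpus_needed(target_rps, batch_size, batch_seconds):
    """GPUs required to keep up with `target_rps`, rounded up."""
    return math.ceil(target_rps / requests_per_second(batch_size, batch_seconds))

sequential = requests_per_second(1, 1.0)  # one request per second
batched = requests_per_second(8, 1.5)     # about 5.33 requests per second
print(f"speedup: {batched / sequential:.1f}x")              # speedup: 5.3x
print(gpus_needed(100, 1, 1.0), gpus_needed(100, 8, 1.5))   # 100 19
```

For the 100 requests/sec load from the opening example, that is the difference between 100 GPUs and 19.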

The Counterintuitive Part:

Many people assume that batching requests means every request in the batch has to wait for the slowest request in that batch. While technically true for a single batch execution, the overall system throughput is what matters for cost. The time saved by keeping the GPU saturated and minimizing idle periods far outweighs the slight increase in latency for individual requests that happen to be in a batch with a slightly longer sequence or more complex generation. You’re trading a small, predictable increase in latency for a massive increase in cost-efficiency and overall request handling capacity.
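A small simulation makes the trade-off visible. It rests on simplifying assumptions: every request is already waiting at time zero, a batch runs until its longest request is done, and each batched decode step costs 15 ms against 10 ms for a single request, mirroring the 1.5x figure used earlier. The function names and token counts are illustrative only.

```python
def sequential_finish_times(token_counts, ms_per_token=10):
    """Finish time (ms) of each request when served one after another."""
    finish, clock = [], 0
    for tokens in token_counts:
        clock += tokens * ms_per_token
        finish.append(clock)
    return finish

def static_batch_finish_times(token_counts, batch_size, ms_per_step=15):
    """Finish time (ms) when each batch runs until its longest request is done."""
    finish, clock = [], 0
    for start in range(0, len(token_counts), batch_size):
        batch = token_counts[start:start + batch_size]
        clock += max(batch) * ms_per_step  # everyone waits for the longest request
        finish.extend([clock] * len(batch))
    return finish

lengths = [20, 100, 35, 60, 80, 15, 50, 40]  # tokens requested by 8 queued requests
seq = sequential_finish_times(lengths)
bat = static_batch_finish_times(lengths, batch_size=8)
print(seq[-1], bat[-1])  # 4000 1500  (time to clear all 8 requests)
print(seq[0], bat[0])    # 200 1500   (the first request pays a latency penalty)
print(seq[5], bat[5])    # 3100 1500  (a request deep in the queue finishes sooner)
```

The request at the front of the line is slower in a batch, but under load most requests are not at the front of the line. Once queueing time is counted, batching can lower latency for the typical request as well as raise throughput.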

The next step is exploring dynamic batching, where the batch size isn’t fixed but adjusts based on incoming traffic and GPU load.
