Prefill Decode Primer

LLM inference optimizations

What happens when you type a message and hit Enter in the ChatGPT app? What happens when you call the Claude API?

Let's dive in and look at how LLM inference is performed at large scale. This post assumes some background in transformers, self-attention, and PyTorch.

Let's start with a simple toy LLM that has just a single attention block. This keeps us focused on the core levers to look at when optimizing inference for these models.

Codepython
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class ToyLLM(nn.Module):
    def __init__(
        self,
        vocab_size: int = 100,
        dim: int = 32,
        max_seq_len: int = 2048,
        device: str = "cpu"
    ):
        super().__init__()
        self.tok_embed = nn.Embedding(vocab_size, dim, device=device)
        self.pos_embed = nn.Embedding(max_seq_len, dim, device=device)

        ## Attention block
        self.q_proj = nn.Linear(dim, dim, bias=False, device=device)
        self.k_proj = nn.Linear(dim, dim, bias=False, device=device)
        self.v_proj = nn.Linear(dim, dim, bias=False, device=device)

        self.out = nn.Linear(dim, vocab_size, bias=False, device=device)

        self.dim = dim
        self.device = device

    def forward(self, input_ids: torch.Tensor):
        B, T = input_ids.shape

        positions = torch.arange(T, device=self.device).unsqueeze(0)
        tok_emb = self.tok_embed(input_ids)
        pos_emb = self.pos_embed(positions)
        x = tok_emb + pos_emb

        ## Calculate self-attention
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.dim ** 0.5)
        mask = torch.tril(torch.ones(T, T, device=self.device)).bool()
        scores = scores.masked_fill(~mask, float("-inf"))

        attn = F.softmax(scores, dim=-1)

        hidden = torch.matmul(attn, v)
        logits = self.out(hidden)

        return logits

Let's take our toy LLM and pass a dummy tokenized prompt through it. The prompt stands in for user input, such as an instruction like "What is the weather today?" or "Generate code". A forward pass returns a tensor of logits with shape (batch, sequence length, vocab size): the logits at each position score the next token for the prefix ending at that position. In other words, the model is always trying to predict what comes next.

Codepython
model = ToyLLM(device=device)
model.eval()

prompt = torch.tensor([[1, 2, 3, 4]], device=device)
logits = model(prompt)
print(logits.shape)
Output
torch.Size([1, 4, 100])

In the example above, we have already provided the first four tokens, so we are interested in the 5th token in the sequence, which is the first token generated by the LLM. We get it by taking the argmax of the logits at the last prompt position, which greedily picks the token with the highest probability (greedy decoding). Because the toy model's weights are randomly initialized, the specific token it picks is arbitrary.

Codepython
next_token_logits = logits[:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
prompt = torch.cat([prompt, next_token], dim=1)
print(prompt)
Output
tensor([[ 1,  2,  3,  4, 14]], device='cuda:0')

Now that we have the 5th token, we append it to the input tokens and run another forward pass to get the 6th token. We repeat this until we hit the max_tokens generation limit or the model generates a special end-of-sequence ([EOS]) token indicating that generation is complete. This is how LLMs generate tokens autoregressively.
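The loop described above can be sketched independently of any model. In this minimal sketch, `step_fn` is a hypothetical stand-in for a forward pass that maps the full token sequence to next-token logits (a plain list of floats), and `eos_id` is an assumed end-of-sequence token id; neither comes from a specific library.

```python
from typing import Callable, List

def greedy_generate(
    step_fn: Callable[[List[int]], List[float]],  # hypothetical: tokens -> next-token logits
    prompt: List[int],
    max_tokens: int,
    eos_id: int,
) -> List[int]:
    """Greedy autoregressive decoding without a KV cache.

    Each step re-runs step_fn on the *entire* sequence so far,
    mirroring model(prompt) in the toy example.
    """
    tokens = list(prompt)
    for _ in range(max_tokens):
        logits = step_fn(tokens)
        # Greedy choice: index of the largest logit.
        next_token = max(range(len(logits)), key=lambda i: logits[i])
        tokens.append(next_token)
        if next_token == eos_id:
            break  # The model signalled that generation is complete.
    return tokens

# Usage with a toy step function that always predicts (last token + 1) mod VOCAB_SIZE.
VOCAB_SIZE = 10

def toy_step(tokens: List[int]) -> List[float]:
    target = (tokens[-1] + 1) % VOCAB_SIZE
    return [1.0 if i == target else 0.0 for i in range(VOCAB_SIZE)]

print(greedy_generate(toy_step, [1, 2, 3, 4], max_tokens=20, eos_id=9))
# → [1, 2, 3, 4, 5, 6, 7, 8, 9]
```

Generation stops either at the token limit or as soon as the EOS id is produced, whichever comes first.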

Codepython
logits = model(prompt)
next_token_logits = logits[:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
prompt = torch.cat([prompt, next_token], dim=1)
print(prompt)
Output
tensor([[ 1,  2,  3,  4, 14, 14]], device='cuda:0')

Notice that in the code above, when we append the new token and run another forward pass, we recompute the K and V projections for every earlier token, even though they were already computed in the previous pass. Recomputing these values on every step instead of reusing them wastes FLOPs. Let's time how long it takes to generate 100 tokens for a batch of 10K prompts, each 50 tokens long (reported numbers are from a T4 GPU).

Codepython
import time

max_tokens = 100
prompt = torch.randint(low=0, high=100, size=(10000, 50), device=device)
print(f"Prompt shape: {prompt.shape}")
with torch.no_grad():
    start = time.perf_counter()
    for _ in range(max_tokens):
        logits = model(prompt)
        next_token_logits = logits[:, -1, :]
        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        prompt = torch.cat([prompt, next_token], dim=1)
    end = time.perf_counter()
print(f"Result shape: {prompt.shape}, Time: {end - start:.6f} seconds")
Output
Prompt shape: torch.Size([10000, 50])
Result shape: torch.Size([10000, 150]), Time: 3.579580 seconds
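Before adding a cache, it helps to count the redundant work in the benchmark above. This small calculation (pure arithmetic, no model needed) counts how many token positions get K/V projections across the 100 generation steps for one 50-token prompt, with and without reusing earlier K/V values.

```python
def kv_projections_without_cache(prompt_len: int, max_tokens: int) -> int:
    # Step i re-runs the model on prompt_len + i tokens and re-projects all of them.
    return sum(prompt_len + i for i in range(max_tokens))

def kv_projections_with_cache(prompt_len: int, max_tokens: int) -> int:
    # Prefill projects the prompt once; each of the remaining max_tokens - 1
    # decode steps projects only the single newest token.
    return prompt_len + (max_tokens - 1)

no_cache = kv_projections_without_cache(50, 100)
with_cache = kv_projections_with_cache(50, 100)
print(no_cache, with_cache, round(no_cache / with_cache, 1))
```

That is 9,950 projected positions versus 149 per sequence, roughly 67x fewer. Attention scores follow the same pattern: without a cache each step builds a full causal score matrix, while a cached decode step only needs one row. The wall-clock speedup need not match this ratio exactly, since not all runtime scales with the number of projected tokens.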

Instead of recomputing the K/V values, let's create a new version of the model that reuses the values computed in previous forward passes. For the new token N, we compute fresh Q/K/V projections and compute attention for the new query against all cached keys. For tokens 1..N-1, we reuse the cached K/V tensors instead of recomputing them. We don't need the queries of earlier tokens, because only the newest position's output is used to predict the next token; that is why only K and V are cached.

[Figure: KV cache animation. The input tokens "The" and "model" pass through the transformer block in parallel, which writes their K/V rows (KV1, KV2) into the KV cache; the last position's logits are then used to generate t3.]

Codepython
class ToyLLMWithCaching(nn.Module):
    def __init__(
        self,
        vocab_size: int = 100,
        dim: int = 32,
        max_seq_len: int = 2048,
        device: str = "cpu"
    ) -> None:
        super().__init__()
        self.tok_embed = nn.Embedding(vocab_size, dim, device=device)
        self.pos_embed = nn.Embedding(max_seq_len, dim, device=device)

        self.q_proj = nn.Linear(dim, dim, bias=False, device=device)
        self.k_proj = nn.Linear(dim, dim, bias=False, device=device)
        self.v_proj = nn.Linear(dim, dim, bias=False, device=device)

        self.out = nn.Linear(dim, vocab_size, bias=False, device=device)

        self.dim = dim
        self.device = device

    def forward(
        self,
        input_ids: torch.Tensor,
        kv_cache: dict[str, torch.Tensor] | None = None
    ):
        B, T = input_ids.shape
        past_len = 0 if kv_cache is None else kv_cache["k"].shape[1]

        positions = torch.arange(past_len, past_len + T, device=self.device).unsqueeze(0)
        tok_emb = self.tok_embed(input_ids)
        pos_emb = self.pos_embed(positions)
        x = tok_emb + pos_emb

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        if kv_cache is not None:
            k = torch.cat([kv_cache["k"], k], dim=1)
            v = torch.cat([kv_cache["v"], v], dim=1)

        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.dim ** 0.5)

        if kv_cache is None:
            mask = torch.tril(torch.ones(T, T, device=self.device)).bool()
            scores = scores.masked_fill(~mask, float("-inf"))
        else:
            q_pos = torch.arange(past_len, past_len + T, device=self.device).unsqueeze(1)
            k_pos = torch.arange(k.shape[1], device=self.device).unsqueeze(0)
            mask = (k_pos <= q_pos).bool()
            scores = scores.masked_fill(~mask, float("-inf"))

        attn = F.softmax(scores, dim=-1)

        hidden = torch.matmul(attn, v)
        logits = self.out(hidden)
        cache = {"k": k, "v": v}

        return logits, cache

model = ToyLLMWithCaching(device=device)
model.eval()
Output
ToyLLMWithCaching(
  (tok_embed): Embedding(100, 32)
  (pos_embed): Embedding(2048, 32)
  (q_proj): Linear(in_features=32, out_features=32, bias=False)
  (k_proj): Linear(in_features=32, out_features=32, bias=False)
  (v_proj): Linear(in_features=32, out_features=32, bias=False)
  (out): Linear(in_features=32, out_features=100, bias=False)
)

Let's time how long it takes to process the same input size as before when we reuse the K/V values (a.k.a. using a KV cache).

Codepython
import time

@torch.no_grad()
def run_model(model, prompt, max_tokens):
    logits, cache = model(prompt) ## Prefill
    generated = []
    next_token_logits = logits[:, -1, :]
    next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
    generated.append(next_token)

    for _ in range(max_tokens-1):
        logits, cache = model(next_token, cache) ## Decode
        next_token_logits = logits[:, -1, :]
        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        generated.append(next_token)
    result = torch.cat([prompt] + generated, dim=1)
    return result

max_tokens = 100
prompt = torch.randint(low=0, high=100, size=(10000, 50), device=device)
print(f"Prompt shape: {prompt.shape}")

start = time.perf_counter()
result = run_model(model, prompt, max_tokens)
end = time.perf_counter()
print(f"Result shape: {result.shape}, Time: {end - start:.6f} seconds")
Output
Prompt shape: torch.Size([10000, 50])
Result shape: torch.Size([10000, 150]), Time: 0.163375 seconds
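One caveat when reading these timings: CUDA kernels are launched asynchronously, so `time.perf_counter()` can stop before the GPU has finished all queued work, and the numbers above should be read as approximate. A more careful measurement synchronizes before reading the clock and discards a warm-up run. The helper below is a sketch, not part of the original benchmark; `timed` and its parameters are hypothetical names, and the `sync` callback is where you would pass `torch.cuda.synchronize` when timing on a GPU.

```python
import time
from typing import Any, Callable, Optional, Tuple

def timed(
    fn: Callable[..., Any],
    *args: Any,
    sync: Optional[Callable[[], None]] = None,  # e.g. torch.cuda.synchronize on a GPU
    warmup: int = 1,
    repeats: int = 3,
) -> Tuple[Any, float]:
    """Return (last result, best wall-clock seconds over `repeats` runs)."""
    def barrier() -> None:
        if sync is not None:
            sync()  # Block until all queued device work has finished.

    result: Any = None
    # Warm-up runs absorb one-time costs such as memory allocation and lazy initialization.
    for _ in range(warmup):
        result = fn(*args)
    barrier()

    best = float("inf")
    for _ in range(repeats):
        barrier()
        start = time.perf_counter()
        result = fn(*args)
        barrier()
        best = min(best, time.perf_counter() - start)
    return result, best
```

With the cached model above and CUDA available, something like `result, secs = timed(run_model, model, prompt, max_tokens, sync=torch.cuda.synchronize)` would time the generation loop; on CPU, leave `sync` as `None`.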

End-to-end latency drops from about 3.58s to about 0.16s, roughly a 22x speedup. This is mainly driven by reducing the FLOPs needed for each token generation step. The code above also shows that generation has two phases:

1. Prefill: the initial forward pass takes the entire user input, builds the KV cache, and produces the logits for the first output token.
2. Decode: the token generation loop uses the KV cache to predict subsequent tokens one at a time.

Prefill is usually compute bound because the whole prompt can be processed in parallel with high arithmetic intensity. Decode is usually memory-bandwidth bound because each step produces only one token per sequence while still reading the model weights and the growing KV cache. This distinction is one reason model providers often charge more for output tokens than input tokens.
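The compute-bound vs memory-bound distinction can be made concrete with arithmetic intensity (FLOPs per byte moved). The sketch below estimates it for a single square linear layer in FP16, assuming the weights and activations are each read or written exactly once (a simplification that ignores on-chip caching and the KV cache itself); `d_model = 8192` is an illustrative assumption, not a value from the toy model.

```python
def linear_arithmetic_intensity(tokens: int, d_in: int, d_out: int, bytes_per_elem: int = 2) -> float:
    """FLOPs per byte for y = x @ W with x: (tokens, d_in), W: (d_in, d_out)."""
    flops = 2 * tokens * d_in * d_out  # one multiply + one add per weight per token
    bytes_moved = bytes_per_elem * (
        d_in * d_out        # read the weight matrix
        + tokens * d_in     # read the input activations
        + tokens * d_out    # write the output activations
    )
    return flops / bytes_moved

d_model = 8192  # illustrative hidden size (assumption)
for tokens in (1, 16, 256, 2048):
    intensity = linear_arithmetic_intensity(tokens, d_model, d_model)
    print(f"{tokens:5d} tokens -> {intensity:8.1f} FLOPs/byte")
```

A single decode token gets about 1 FLOP per byte, while a 2048-token prefill gets over 1,000, because the same weights are reused across every prompt token. Batching many sequences during decode raises `tokens` for the weight matmuls, though attention over each sequence's own KV cache does not benefit from batching in the same way.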

Note: One thing we are glossing over for simplicity is that, during decoding, different requests finish generation at different lengths. The code above generates exactly 100 tokens for every input, but real requests stop at different points (for example, when they emit [EOS]). Serving systems handle this with continuous batching, which removes finished sequences from the batch and admits new requests at each decode step instead of waiting for the whole batch to finish.
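A toy simulation shows why continuous batching helps. The sketch below counts decode steps for requests with different output lengths under two hypothetical policies: static batching, where a batch runs until its longest request finishes, and continuous batching, where a finished request's slot is refilled from the waiting queue before the next step. Prefill cost is ignored, and the request lengths are made-up illustrative values.

```python
from collections import deque
from typing import List

def static_batching_steps(output_lens: List[int], batch_size: int) -> int:
    steps = 0
    for i in range(0, len(output_lens), batch_size):
        batch = output_lens[i:i + batch_size]
        steps += max(batch)  # The whole batch waits for its longest request.
    return steps

def continuous_batching_steps(output_lens: List[int], batch_size: int) -> int:
    waiting = deque(output_lens)
    running: List[int] = []  # Remaining tokens for each active request.
    steps = 0
    while waiting or running:
        # Admit waiting requests into free slots before each decode step.
        while waiting and len(running) < batch_size:
            running.append(waiting.popleft())
        # One decode step emits one token for every running request.
        steps += 1
        running = [r - 1 for r in running if r > 1]  # Drop requests that just finished.
    return steps

lens = [5, 100, 5, 5, 100, 5, 5, 5]
print(static_batching_steps(lens, batch_size=4), continuous_batching_steps(lens, batch_size=4))
```

Here static batching needs 200 steps, while continuous batching needs 105, because the short requests stop occupying slots the moment they finish. Real continuous batching (also called iteration-level scheduling) must also fit prefill for newly admitted requests and manage per-request KV cache memory, which this sketch leaves out.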

Now that we have built the KV cache mechanism, we can use it to serve our LLM from a single GPU. For each new request, we alternate between prefill and decode steps to generate output tokens and send them back to the user. However, this can hurt TTFT (Time To First Token): requests with long prompts can monopolize the GPU during prefill and hold up requests with shorter prompts. Take a look at the visualization below.

[Figure: Alternating prefill/decode hurts TTFT. On a single GPU, Req A (long prefix) runs a 300ms prefill. Req B (short prefix, 20ms prefill) arrives at 50ms and is blocked until A's prefill finishes, so TTFT(B) grows to about 300ms. A long-running prefill monopolizes the GPU and delays subsequent requests.]

To improve TTFT, we can pause decoding when a new request arrives, run its prefill, and then resume decoding. This keeps TTFT more predictable for the new request, but hurts ITL (Inter-Token Latency) for in-flight requests because their decode streams are paused while the prefill runs.

[Figure: Pausing decode for a new prefill. Req A finishes prefill and starts decoding; when Req B arrives, A's decode is paused while B's prefill runs, then A's decode resumes. TTFT(B) stays predictable (~60ms), but ITL(A) increases while decode is paused.]

To keep both TTFT and ITL predictable, we can use chunked prefill: split each prompt into fixed-size chunks so that prefill and decode work can be interleaved on the GPU. This can slightly increase TTFT for long prompts, but we can accept this tradeoff since we get more predictable latency across both phases and better resource utilization.
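Chunked prefill can be sketched as a per-step token budget. In this hypothetical scheduler (the names, the budget of 8 tokens, and the decode-first policy are illustrative assumptions, not any serving framework's API), every in-flight decode request gets its one token first, and whatever budget remains is filled with the next chunks of pending prompts.

```python
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class Request:
    name: str
    prompt_left: int  # Prompt tokens not yet prefilled.

def schedule_step(decoding: List[str], prefilling: List[Request], token_budget: int) -> Dict[str, int]:
    """Return {request name: tokens processed this step}."""
    plan: Dict[str, int] = {}
    # 1) Decode first: each in-flight request advances by one token, keeping ITL steady.
    for name in decoding:
        if token_budget == 0:
            break
        plan[name] = 1
        token_budget -= 1
    # 2) Spend the remaining budget on prefill chunks, in arrival order.
    for req in prefilling:
        if token_budget == 0:
            break
        chunk = min(req.prompt_left, token_budget)
        if chunk > 0:
            plan[req.name] = chunk
            req.prompt_left -= chunk
            token_budget -= chunk
    return plan

decoding = ["A"]  # Request A is already generating tokens.
prefilling = [Request("B", prompt_left=12), Request("C", prompt_left=3)]
step = 0
while any(r.prompt_left > 0 for r in prefilling):
    step += 1
    print(step, schedule_step(decoding, prefilling, token_budget=8))
```

Request A emits a token on every step while B's and C's prompts are prefilled in pieces, so no single prompt stalls the decode stream. A real scheduler would move a request into the decode set once its prompt is fully prefilled, and a round-robin policy that rotates chunks across requests is another option.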

[Figure: Chunked prefill shared across requests. Instead of one request monopolizing prefill, the scheduler rotates fixed-size chunks across Req A (long prefix), Req B (medium prefix), and Req C (short prefix), interleaved with decode steps (A1, B1, C1, A2, B2, C2, A3, ...). This adds a small TTFT tax, but no single request monopolizes prefill and the decode cadence stays predictable.]

Now, say we have an 8-GPU node or a GPU cluster in a data center. Instead of running prefill and decode on the same GPUs, we can disaggregate them and run each phase on dedicated GPUs. This lets us scale the two pools independently to match their different workloads, which improves resource utilization. Splitting the workloads also reduces unpredictable spikes in TTFT and ITL caused by prefill and decode competing for the same GPU.

There is a tradeoff, though. When prefill and decode run on the same GPU, decode can reuse the KV cache directly from HBM; when they run on different GPUs, we have to move the KV cache between them. And KV caches can get big. For a Llama 70B-style model with an FP16 KV cache, 80 layers, 8 KV heads, and 128-dimensional heads, the KV cache takes roughly 320 KiB per live token across the full model. Across long contexts and many concurrent sequences, this adds up to tens of GB: around 120K-240K live tokens is roughly 40-80 GB. Moving that much data, rather than the FLOPs we can do, can become the bottleneck for generating tokens quickly. NVIDIA's NVLink/NVSwitch can speed up KV cache block transfers between GPUs within a node, and RDMA can be used to transfer cache blocks across nodes.
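The KV cache figures above follow from a short calculation: 2 tensors (K and V) x layers x KV heads x head dimension x bytes per element, using the Llama 70B-style shape from the text. The transfer-time helper takes link bandwidth as an explicit input because effective bandwidth depends on the hardware and transport; the 50 GB/s used below is a placeholder assumption, not a measured NVLink or RDMA number.

```python
def kv_bytes_per_token(n_layers: int, n_kv_heads: int, head_dim: int, bytes_per_elem: int) -> int:
    # Factor of 2: one K vector and one V vector per KV head per layer.
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem

def transfer_seconds(num_bytes: float, bandwidth_bytes_per_s: float) -> float:
    # Idealized: ignores latency, protocol overhead, and overlap with compute.
    return num_bytes / bandwidth_bytes_per_s

per_token = kv_bytes_per_token(n_layers=80, n_kv_heads=8, head_dim=128, bytes_per_elem=2)  # FP16
print(f"{per_token} bytes/token = {per_token / 1024:.0f} KiB")

PLACEHOLDER_BANDWIDTH = 50e9  # bytes/s; assumed for illustration only
for live_tokens in (120_000, 240_000):
    total = per_token * live_tokens
    secs = transfer_seconds(total, PLACEHOLDER_BANDWIDTH)
    print(f"{live_tokens:,} tokens -> {total / 1e9:.1f} GB, ~{secs:.2f} s at the placeholder bandwidth")
```

This gives 327,680 bytes (320 KiB) per token, and roughly 39 GB and 79 GB for 120K and 240K live tokens, consistent with the 40-80 GB range above.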

[Figure: Disaggregated serving. A prefill GPU pool (compute bound) and a decode GPU pool (memory bound) run independently; KV cache blocks are transferred between them over the RDMA/NVLink fabric, trading transfer overhead for predictable TTFT/ITL.]
Disaggregating prefill and decode lets each pool scale independently, while RDMA/NVLink carries KV cache blocks between them.

These optimizations allow us to serve LLMs with higher throughput, more predictable latency, and better GPU utilization.