Prefill Decode Primer


  • Show the steps in inference
  • Show how there are two phases to it, prefill and decode
  • Importance of KV caching; show how performance improves with longer prefix matches
  • Show how the two are different workloads and disaggregating them can help optimize; doing both on the same GPUs can mean underutilization
  • In prefill you measure TTFT, but in decode you care about per-token latency (ITL) and want the overall generation time low too
  • Key point: what happens if, during decoding, one request completes before another? How can we do continuous batching?
  • You can run them on different GPUs, but it must be the same node, or at least the same DC, so the KV cache can be transferred quickly
  • Paged Attention: talk about how this reduces tail latency
  • Attention architecture updates have reduced the amount of data you need to transfer without sacrificing performance
  • Show how the Kimi paper allows you to do that transfer across data centers with higher throughput
  • Why do we even want to do this across data centers?

What happens when you type a message and hit enter in your ChatGPT app? What happens when you call the Claude API?

If you are curious how LLM inference works and how these model providers manage to serve millions of requests while efficiently utilizing their large GPU clusters, then I hope this blog sheds some light. This post assumes you are familiar with transformers and self-attention, and that you have written or read some PyTorch code before.

Let's start with a simple toy LLM that is just a single attention block. This helps us focus on the core levers we need to look at when attempting to optimize inference for these models.

Codepython
import torch
import torch.nn as nn
import torch.nn.functional as F
Codepython
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Codepython
class ToyLLM(nn.Module):
    def __init__(
        self,
        vocab_size: int = 100,
        dim: int = 32,
        max_seq_len: int = 2048,
        device: str = "cpu"
    ):
        super().__init__()
        self.tok_embed = nn.Embedding(vocab_size, dim, device=device)
        self.pos_embed = nn.Embedding(max_seq_len, dim, device=device)

        ## Attention block
        self.q_proj = nn.Linear(dim, dim, bias=False, device=device)
        self.k_proj = nn.Linear(dim, dim, bias=False, device=device)
        self.v_proj = nn.Linear(dim, dim, bias=False, device=device)

        self.out = nn.Linear(dim, vocab_size, bias=False, device=device)

        self.dim = dim
        self.device = device

    def forward(self, input_ids: torch.Tensor):
        B, T = input_ids.shape

        positions = torch.arange(T, device=self.device).unsqueeze(0)
        tok_emb = self.tok_embed(input_ids)
        pos_emb = self.pos_embed(positions)
        x = tok_emb + pos_emb

        ## Calculate self-attention
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.dim ** 0.5)
        mask = torch.tril(torch.ones(T, T, device=self.device)).bool()
        scores = scores.masked_fill(~mask, float("-inf"))

        attn = F.softmax(scores, dim=-1)

        hidden = torch.matmul(attn, v)
        logits = self.out(hidden)

        return logits

Let's take our toy LLM and pass a tokenized prompt through it. The prompt maps to the user input. In a forward pass, the model returns an array of logits: at each input token position, it predicts the most likely next token.

Codepython
model = ToyLLM(device=device)
model.eval()

prompt = torch.tensor([[1, 2, 3, 4]], device=device)
logits = model(prompt)
print(logits.shape)
Output
torch.Size([1, 4, 100])

In the above example, the user has already provided the first four tokens, so we are interested in the 5th token generated by the LLM. This can be obtained by taking the argmax over the logits at the last position.

Codepython
next_token_logits = logits[:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
prompt = torch.cat([prompt, next_token], dim=1)
print(prompt)
Output
tensor([[ 1,  2,  3,  4, 21]], device='cuda:0')

Now that we have our 5th token, we append it to the original user input and do another forward pass through the model to get the 6th token. We iterate this process to generate more tokens until we hit our max_tokens generation limit or the model emits a special [EOS] token indicating that generation is complete.

Codepython
logits = model(prompt)
next_token_logits = logits[:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
prompt = torch.cat([prompt, next_token], dim=1)
print(prompt)
Output
tensor([[ 1,  2,  3,  4, 47,  4]], device='cuda:0')

If you look at the above code, after generating a single new token we append it to the input and recompute the K and V values in our attention block for all the previous tokens, even though we already calculated them in the previous forward pass. In other words, we are wasting FLOPs. Let's time how long it takes to generate 100 tokens at a batch size of 10K (reported numbers are from a T4 GPU run in Google Colab).

Codepython
import time

max_tokens = 100
prompt = torch.randint(low=0, high=100, size=(10000, 50), device=device)
print(f"Prompt shape: {prompt.shape}")
start = time.perf_counter()
with torch.no_grad():
    for _ in range(max_tokens):
        logits = model(prompt)
        next_token_logits = logits[:, -1, :]
        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        prompt = torch.cat([prompt, next_token], dim=1)
torch.cuda.synchronize()  ## make sure the timer captures all queued GPU work
end = time.perf_counter()
print(f"Result shape: {prompt.shape}, Time: {end - start:.6f} seconds")
Output
Prompt shape: torch.Size([10000, 50])
Result shape: torch.Size([10000, 150]), Time: 4.052675 seconds
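Before fixing this, let's put a rough number on the wasted work. The count below is my own simplification: it covers only the Q/K/V projection matmuls (2 * tokens * dim * dim FLOPs each) and ignores the attention scores and output layer, but it shows the shape of the problem.

```python
## Rough FLOP count for the Q/K/V projections only; this ignores the
## attention-score matmuls and the output layer, so it understates the
## total work, but the ratio still tells the story.
def proj_flops_no_cache(P, N, d):
    ## each of the N steps re-projects the entire sequence seen so far
    return sum(3 * 2 * (P + i) * d * d for i in range(N))

def proj_flops_with_cache(P, N, d):
    ## prefill projects the P prompt tokens once, then one token per step
    return 3 * 2 * P * d * d + N * 3 * 2 * d * d

no_cache = proj_flops_no_cache(P=50, N=100, d=32)
cached = proj_flops_with_cache(P=50, N=100, d=32)
print(f"{no_cache / cached:.0f}x fewer projection FLOPs with a KV cache")
```

For our 50-token prompt and 100 generated tokens, this works out to roughly a 66x reduction in projection FLOPs.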

In an attempt to speed up token generation, let's create a new model, but this time with KV caching. After each forward pass, the K and V values from that pass are stored and reused, so on the next pass we only compute attention for the new token N; for tokens 1..N-1 we simply reuse the cached KV values.

Codepython
class ToyLLMWithCaching(nn.Module):
    def __init__(
        self,
        vocab_size: int = 100,
        dim: int = 32,
        max_seq_len: int = 2048,
        device: str = "cpu"
    ) -> None:
        super().__init__()
        self.tok_embed = nn.Embedding(vocab_size, dim, device=device)
        self.pos_embed = nn.Embedding(max_seq_len, dim, device=device)

        self.q_proj = nn.Linear(dim, dim, bias=False, device=device)
        self.k_proj = nn.Linear(dim, dim, bias=False, device=device)
        self.v_proj = nn.Linear(dim, dim, bias=False, device=device)

        self.out = nn.Linear(dim, vocab_size, bias=False, device=device)

        self.dim = dim
        self.device = device

    def forward(
        self,
        input_ids: torch.Tensor,
        kv_cache: dict[str, torch.Tensor] | None = None
    ):
        B, T = input_ids.shape
        past_len = 0 if kv_cache is None else kv_cache["k"].shape[1]

        positions = torch.arange(past_len, past_len + T, device=self.device).unsqueeze(0)
        tok_emb = self.tok_embed(input_ids)
        pos_emb = self.pos_embed(positions)
        x = tok_emb + pos_emb

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        if kv_cache is not None:
            k = torch.cat([kv_cache["k"], k], dim=1)
            v = torch.cat([kv_cache["v"], v], dim=1)

        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.dim ** 0.5)

        if kv_cache is None:
            mask = torch.tril(torch.ones(T, T, device=self.device)).bool()
            scores = scores.masked_fill(~mask, float("-inf"))
        else:
            q_pos = torch.arange(past_len, past_len + T, device=self.device).unsqueeze(1)
            k_pos = torch.arange(k.shape[1], device=self.device).unsqueeze(0)
            mask = (k_pos <= q_pos).bool()
            scores = scores.masked_fill(~mask, float("-inf"))

        attn = F.softmax(scores, dim=-1)

        hidden = torch.matmul(attn, v)
        logits = self.out(hidden)
        cache = {"k": k, "v": v}

        return logits, cache
Codepython
model = ToyLLMWithCaching(device=device)
model.eval()
Output
ToyLLMWithCaching(
  (tok_embed): Embedding(100, 32)
  (pos_embed): Embedding(2048, 32)
  (q_proj): Linear(in_features=32, out_features=32, bias=False)
  (k_proj): Linear(in_features=32, out_features=32, bias=False)
  (v_proj): Linear(in_features=32, out_features=32, bias=False)
  (out): Linear(in_features=32, out_features=100, bias=False)
)
Codepython
import time

@torch.no_grad()
def run_model(model, prompt, max_tokens):
    logits, cache = model(prompt) ## Prefill
    generated = []
    next_token_logits = logits[:, -1, :]
    next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
    generated.append(next_token)

    for _ in range(max_tokens-1):
        logits, cache = model(next_token, cache) ## Decode
        next_token_logits = logits[:, -1, :]
        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        generated.append(next_token)
    result = torch.cat([prompt] + generated, dim=1)
    return result

max_tokens = 100
prompt = torch.randint(low=0, high=100, size=(10000, 50), device=device)
print(f"Prompt shape: {prompt.shape}")

start = time.perf_counter()
result = run_model(model, prompt, max_tokens)
end = time.perf_counter()
print(f"Result shape: {result.shape}, Time: {end - start:.6f} seconds")
Output
Prompt shape: torch.Size([10000, 50])
Result shape: torch.Size([10000, 150]), Time: 0.334071 seconds

We see a significant speedup in E2E latency, mainly driven by the reduction in FLOPs per token-generation step. In the above code you can also see that generation has two phases: the initial forward pass, which takes the user input and builds the KV cache, and the token-generation steps, which use the KV cache to predict the next token. These are the prefill and decode phases respectively. Prefill is compute-bound: its FLOPs are done in parallel across the prompt, and its cost grows with prompt length. Decode is memory-bound: each step has to stream the entire KV cache, so token generation becomes more expensive as the cache grows.
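One way to make the compute-bound vs memory-bound split concrete is arithmetic intensity: FLOPs performed per byte moved to and from memory. The model below is my own back-of-envelope for the QK^T matmul of a single fp16 attention head, ignoring caches, kernel fusion, and every other hardware detail.

```python
## Arithmetic intensity (FLOPs per byte) of the QK^T matmul, single head.
def intensity(T, d, dtype_bytes=2):
    ## prefill: Q(T,d) @ K^T(d,T) does 2*T*T*d FLOPs while moving
    ## Q and K (T*d elements each) plus the T*T score matrix
    prefill_flops = 2 * T * T * d
    prefill_bytes = (2 * T * d + T * T) * dtype_bytes
    ## decode: q(1,d) @ K^T(d,T) does only 2*T*d FLOPs but still has to
    ## stream the entire cached K (T*d elements) from memory
    decode_flops = 2 * T * d
    decode_bytes = (T * d + d + T) * dtype_bytes
    return prefill_flops / prefill_bytes, decode_flops / decode_bytes

prefill_ai, decode_ai = intensity(T=4096, d=128)
print(f"prefill: {prefill_ai:.1f} FLOPs/byte, decode: {decode_ai:.1f} FLOPs/byte")
```

At T=4096 and d=128 this gives roughly 120 FLOPs/byte for prefill versus about 1 FLOP/byte for decode. A modern GPU can do far more than one FLOP in the time it takes to move one byte, so decode sits firmly in memory-bound territory.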

Let's run a simple benchmark to profile the prefill and decode steps in our toy LLM.

Note: one thing we are glossing over for the sake of simplicity is that different requests in a batch can finish generation at different lengths, yet with naive static batching the whole batch keeps stepping until the longest request is done. This is solved by continuous batching, which we will get to in another post.
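To build some intuition anyway, here is a toy scheduling simulation that ignores the model entirely (the request lengths and batch size are made up): each decode step advances every active request by one token, and the moment a request finishes, its slot is handed to a waiting request instead of idling until the whole batch drains.

```python
from collections import deque

def continuous_batching(request_lengths, max_batch):
    ## (request_id, tokens left to generate) waiting for a slot
    waiting = deque(enumerate(request_lengths))
    active = {}        ## request_id -> tokens remaining
    completed = []
    steps = 0
    while waiting or active:
        ## admit waiting requests into any free slots
        while waiting and len(active) < max_batch:
            rid, n = waiting.popleft()
            active[rid] = n
        ## one decode step for every active request
        steps += 1
        for rid in list(active):
            active[rid] -= 1
            if active[rid] == 0:   ## finished: slot frees up next step
                del active[rid]
                completed.append(rid)
    return steps, completed

steps, order = continuous_batching([3, 10, 2, 2, 5], max_batch=2)
print(steps, order)   ## 12 [0, 2, 3, 1, 4]
```

The same workload with static batches of two ({0,1}, then {2,3}, then {4}) would take max(3,10) + max(2,2) + 5 = 17 steps, with the short requests' slots sitting idle while their batch-mate finishes.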

Codepython
import time
import torch
from torch.profiler import profile, ProfilerActivity

device = "cuda"
model = ToyLLMWithCaching(device=device).eval()

batch_size = 50000
prompt_len = 50
max_tokens = 100

prompt = torch.randint(low=0, high=100, size=(batch_size, prompt_len), device=device)

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
    with_stack=False,
) as prof_prefill:
    with torch.no_grad():
        start = time.perf_counter()
        logits, cache = model(prompt)
        torch.cuda.synchronize()
        end = time.perf_counter()

prefill_time = end - start

print("\n=== PREFILL ===")
print(f"time: {prefill_time:.6f} s")
print(f"max CUDA memory allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")

print(
    prof_prefill.key_averages().table(
        sort_by="cuda_time_total",
        row_limit=10
    )
)


next_token_logits = logits[:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
    with_stack=False,
) as prof_decode:
    with torch.no_grad():
        start = time.perf_counter()
        for _ in range(max_tokens - 1):
            logits, cache = model(next_token, cache)
            next_token_logits = logits[:, -1, :]
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        torch.cuda.synchronize()
        end = time.perf_counter()

decode_time = end - start

print("\n=== DECODE ===")
print(f"time: {decode_time:.6f} s")
print(f"max CUDA memory allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")

print(
    prof_decode.key_averages().table(
        sort_by="cuda_time_total",
        row_limit=10
    )
)
Output

=== PREFILL ===
time: 0.076620 s
max CUDA memory allocated: 5629.88 MB
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::matmul         0.18%     133.959us         4.89%       3.717ms     619.523us       0.000us         0.00%      29.532ms       4.922ms           0 B           0 B       2.59 GB           0 B             6  
                                              aten::bmm         0.16%     119.914us         1.41%       1.074ms     537.044us      18.164ms        36.57%      18.164ms       9.082ms           0 B           0 B     782.84 MB     782.84 MB             2  
                                           aten::linear         0.04%      28.435us         3.44%       2.614ms     653.479us       0.000us         0.00%      11.368ms       2.842ms           0 B           0 B       1.83 GB           0 B             4  
                                               aten::mm         0.64%     483.104us         3.14%       2.386ms     596.467us      11.368ms        22.89%      11.368ms       2.842ms           0 B           0 B       1.83 GB       1.83 GB             4  
                                 volta_sgemm_128x128_tn         0.00%       0.000us         0.00%       0.000us       0.000us      10.637ms        21.42%      10.637ms       3.546ms           0 B           0 B           0 B           0 B             3  
                                   volta_sgemm_64x64_nn         0.00%       0.000us         0.00%       0.000us       0.000us      10.231ms        20.60%      10.231ms      10.231ms           0 B           0 B           0 B           0 B             1  
                                      aten::masked_fill         0.02%      17.215us         0.17%     127.410us     127.410us       0.000us         0.00%       9.348ms       9.348ms           0 B           0 B     476.84 MB           0 B             1  
                                   volta_sgemm_64x64_tn         0.00%       0.000us         0.00%       0.000us       0.000us       7.933ms        15.97%       7.933ms       7.933ms           0 B           0 B           0 B           0 B             1  
                                          aten::softmax         0.01%       6.145us         0.06%      45.025us      45.025us       0.000us         0.00%       6.058ms       6.058ms           0 B           0 B     476.84 MB           0 B             1  
                                         aten::_softmax         0.03%      19.782us         0.05%      38.880us      38.880us       6.058ms        12.20%       6.058ms       6.058ms           0 B           0 B     476.84 MB     476.84 MB             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 75.941ms
Self CUDA time total: 49.666ms


=== DECODE ===
time: 2.109107 s
max CUDA memory allocated: 3794.50 MB
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                              aten::cat         0.23%       4.837ms        85.85%        1.781s       8.994ms        1.240s        63.33%        1.259s       6.360ms           0 B           0 B     118.03 GB     118.03 GB           198  
void at::native::(anonymous namespace)::CatArrayBatc...         0.00%       0.000us         0.00%       0.000us       0.000us        1.240s        63.33%        1.240s       6.326ms           0 B           0 B           0 B           0 B           196  
                                           aten::matmul         0.35%       7.303ms         2.90%      60.241ms     101.415us       0.000us         0.00%     625.961ms       1.054ms           0 B           0 B       6.15 GB           0 B           594  
                                              aten::bmm         0.37%       7.739ms         1.00%      20.801ms     105.057us     569.914ms        29.11%     587.082ms       2.965ms           0 B           0 B       2.46 GB       2.46 GB           198  
std::enable_if<!(false), void>::type internal::gemvx...         0.00%       0.000us         0.00%       0.000us       0.000us     292.407ms        14.93%     292.407ms       2.984ms           0 B           0 B           0 B           0 B            98  
std::enable_if<!(false), void>::type internal::gemvx...         0.00%       0.000us         0.00%       0.000us       0.000us     277.507ms        14.17%     277.507ms       2.803ms           0 B           0 B           0 B           0 B            99  
                                           aten::linear         0.08%       1.603ms         1.76%      36.437ms      92.013us       0.000us         0.00%      38.879ms      98.179us           0 B           0 B       3.69 GB           0 B           396  
                                               aten::mm         0.63%      13.160ms         1.17%      24.342ms      61.469us      38.209ms         1.95%      38.879ms      98.179us           0 B           0 B       3.69 GB       3.69 GB           396  
                                       cudaLaunchKernel         1.13%      23.448ms         2.59%      53.717ms      27.130us       0.000us         0.00%      38.241ms      19.314us           0 B           0 B           0 B           0 B          1980  
                                    Command Buffer Full         1.46%      30.269ms         1.46%      30.269ms       1.593ms      38.241ms         1.95%      38.241ms       2.013ms           0 B           0 B           0 B           0 B            19  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 2.074s
Self CUDA time total: 1.958s

Say we have an 8-GPU node and want to serve our LLM from it. If we alternate between prefill and decode runs on the entire node, we hurt both our TTFT and ITL metrics, because in-flight prefills hold up the token-generation process. Our users won't be happy, and neither will our boss, since this also leads to under-utilization of the GPUs.

There are two key changes we can implement here to improve things: 1) prefill/decode disaggregation and 2) chunked prefill.

PD disaggregation lets you maintain a predictable ITL, split the two workloads, and utilize your GPUs efficiently. However, there is one gotcha: when you disaggregate prefill and decode onto different GPUs, you need to move the KV cache between them. This can be sped up with RDMA, but it is costly, and there is a tradeoff: TTFT increases because you now have to account for the transfer before decoding can start and the first token can be produced.
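To get a feel for the cost, here is a back-of-envelope estimate. The model configuration (32 layers, 32 KV heads of dimension 128, fp16) and the 50 GB/s link speed are hypothetical numbers picked for illustration; plug in your own.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, dtype_bytes=2):
    ## two tensors (K and V) per layer, one head_dim vector per head per token
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * dtype_bytes

cache = kv_cache_bytes(n_layers=32, n_kv_heads=32, head_dim=128, seq_len=4096)
print(f"{cache / 1e9:.2f} GB of KV cache for one 4096-token request")

## added TTFT from shipping that cache over a 50 GB/s RDMA link
link_bytes_per_s = 50e9
print(f"{cache / link_bytes_per_s * 1e3:.1f} ms of transfer time")
```

That is about 2 GB and roughly 40 ms per request, which is why the transfer has to happen over fast interconnects, and why attention variants that shrink the KV cache reduce this tax directly.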

Chunked prefill ensures that a single long request does not monopolize the compute. It splits the prompt tokens into chunks and prefills them a chunk at a time. This translates into multiple forward passes through the model to build the KV cache for all the tokens, and the scheduler interleaves these passes with other batches so that requests of all lengths can make progress.
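To convince ourselves that chunked prefill computes exactly the same thing as one big prefill, here is a standalone single-head sketch (it mirrors the cache logic of ToyLLMWithCaching but is self-contained): each chunk's K/V are appended to a running cache, and the chunk's queries attend causally over everything cached so far.

```python
import torch

torch.manual_seed(0)
dim, T, chunk = 16, 12, 4
x = torch.randn(1, T, dim)
Wq, Wk, Wv = (torch.randn(dim, dim) for _ in range(3))

def attend(q, k, v, past_len):
    ## causal attention where q starts at absolute position past_len
    scores = q @ k.transpose(-2, -1) / dim ** 0.5
    q_pos = torch.arange(past_len, past_len + q.shape[1]).unsqueeze(1)
    k_pos = torch.arange(k.shape[1]).unsqueeze(0)
    scores = scores.masked_fill(k_pos > q_pos, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

## full prefill: all T tokens in one pass
q, k, v = x @ Wq, x @ Wk, x @ Wv
full = attend(q, k, v, past_len=0)

## chunked prefill: `chunk` tokens at a time, growing a KV cache
cache_k = torch.empty(1, 0, dim)
cache_v = torch.empty(1, 0, dim)
outs = []
for s in range(0, T, chunk):
    xc = x[:, s:s + chunk]
    qc, kc, vc = xc @ Wq, xc @ Wk, xc @ Wv
    cache_k = torch.cat([cache_k, kc], dim=1)
    cache_v = torch.cat([cache_v, vc], dim=1)
    outs.append(attend(qc, cache_k, cache_v, past_len=s))
chunked = torch.cat(outs, dim=1)

print(torch.allclose(full, chunked, atol=1e-5))  ## True
```

The outputs match because every entry the full pass masks out is exactly an entry the chunked pass never computes; between chunks, the scheduler is free to slip in decode steps for other requests.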

Okay, we can do this within a single node. But what about across nodes, and how would you even do this across data centers?