- Show steps in inference
- Show how there are two parts to it, prefill and decode
- Importance of KV caching, show how perf improves as there are longer prefix matches
- Show how both are different workloads and having them disaggregated can help optimize. Doing both on same GPUs can mean underutilization
- In prefill mode you measure TTFT, but in decode mode you want maximal tok/sec (i.e. minimal inter-token latency) and a low overall latency too
- Key point - what happens if, during decoding, one request completes before another? How can we do continuous batching?
- You can do them on different GPUs, but it must be on the same node, or at least in the same DC, so that you can transfer the KV cache quickly
- Paged Attention - talk about how this reduces tail latency
- Attention architecture updates have reduced the amount of data you need to transfer without sacrificing perf
- Show how Kimi paper allows you to do that transfer across data center with higher throughput.
- Why do we even want to do this across data center
What happens when you type a message and hit enter on your ChatGPT app? What happens when you call Claude API?
If you are curious about how LLM inference works and how these model providers are able to serve millions of requests while efficiently utilizing their large GPU clusters, then I hope this blog sheds some light. This post assumes that you are familiar with transformers and self-attention, and have written or seen some PyTorch code before.
Let's start with a simple toy LLM that is just an attention block. This helps us focus on the core levers we need to look for when attempting to optimize the inference of these models.
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class ToyLLM(nn.Module):
def __init__(
self,
vocab_size: int = 100,
dim: int = 32,
max_seq_len: int = 2048,
device: str = "cpu"
):
super().__init__()
self.tok_embed = nn.Embedding(vocab_size, dim, device=device)
self.pos_embed = nn.Embedding(max_seq_len, dim, device=device)
## Attention block
self.q_proj = nn.Linear(dim, dim, bias=False, device=device)
self.k_proj = nn.Linear(dim, dim, bias=False, device=device)
self.v_proj = nn.Linear(dim, dim, bias=False, device=device)
self.out = nn.Linear(dim, vocab_size, bias=False, device=device)
self.dim = dim
self.device = device
def forward(self, input_ids: torch.Tensor):
B, T = input_ids.shape
positions = torch.arange(T, device=self.device).unsqueeze(0)
tok_emb = self.tok_embed(input_ids)
pos_emb = self.pos_embed(positions)
x = tok_emb + pos_emb
## Calculate self-attention
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
scores = torch.matmul(q, k.transpose(-2, -1)) / (self.dim ** 0.5)
mask = torch.tril(torch.ones(T, T, device=self.device)).bool()
scores = scores.masked_fill(~mask, float("-inf"))
attn = F.softmax(scores, dim=-1)
hidden = torch.matmul(attn, v)
logits = self.out(hidden)
return logits

Let's take our toy LLM and pass a tokenized prompt through it. The prompt here maps to the user input. By doing a forward pass, the model returns an array of logits, i.e. at each input token position it predicts the most likely next token.
model = ToyLLM(device=device)
model.eval()
prompt = torch.tensor([[1, 2, 3, 4]], device=device)
logits = model(prompt)
print(logits.shape)
torch.Size([1, 4, 100])
In the above example, the user has already provided the first four tokens, so we are interested in the 5th token generated by the LLM. This can be obtained by taking the argmax over the logits at the last token position.
next_token_logits = logits[:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
prompt = torch.cat([prompt, next_token], dim=1)
print(prompt)
tensor([[ 1, 2, 3, 4, 21]], device='cuda:0')
Now that we have our 5th token, we append it to the original user input and do another forward pass through the model to get the 6th token. We keep iterating this process to generate more tokens until we hit our max_token generation limit or the model generates a special [EOS] token indicating that the generation is complete.
logits = model(prompt)
next_token_logits = logits[:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
prompt = torch.cat([prompt, next_token], dim=1)
print(prompt)
tensor([[ 1, 2, 3, 4, 47, 4]], device='cuda:0')
If you notice in the above code, after generating a single new token, we append it to the input and regenerate the KV values in our attention block for all the previous tokens, which we already calculated in the previous forward pass. In other words, we are wasting FLOPs. Let's add a simple timer to see how long it takes to generate 100 tokens for a batch size of 10K (reported numbers are from a T4 GPU run in Google Colab).
import time
max_tokens = 100
prompt = torch.randint(low=0, high=100, size=(10000, 50), device=device)
print(f"Prompt shape: {prompt.shape}")
start = time.perf_counter()
for _ in range(max_tokens):
logits = model(prompt)
next_token_logits = logits[:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
prompt = torch.cat([prompt, next_token], dim=1)
end = time.perf_counter()
print(f"Result shape: {prompt.shape}, Time: {end - start:.6f} seconds")
Prompt shape: torch.Size([10000, 50])
Result shape: torch.Size([10000, 150]), Time: 4.052675 seconds
In an attempt to speed up the token generation process, let's create a new model, but this time with KV caching. After each forward pass, the K and V values from the previous passes are stored and reused. This means we only compute attention projections for the new token N; for tokens 1..N-1 we simply reuse the cached KV values.
class ToyLLMWithCaching(nn.Module):
def __init__(
self,
vocab_size: int = 100,
dim: int = 32,
max_seq_len: int = 2048,
device: str = "cpu"
) -> None:
super().__init__()
self.tok_embed = nn.Embedding(vocab_size, dim, device=device)
self.pos_embed = nn.Embedding(max_seq_len, dim, device=device)
self.q_proj = nn.Linear(dim, dim, bias=False, device=device)
self.k_proj = nn.Linear(dim, dim, bias=False, device=device)
self.v_proj = nn.Linear(dim, dim, bias=False, device=device)
self.out = nn.Linear(dim, vocab_size, bias=False, device=device)
self.dim = dim
self.device = device
def forward(
self,
input_ids: torch.Tensor,
kv_cache: dict[str, torch.Tensor] | None = None
):
B, T = input_ids.shape
past_len = 0 if kv_cache is None else kv_cache["k"].shape[1]
positions = torch.arange(past_len, past_len + T, device=self.device).unsqueeze(0)
tok_emb = self.tok_embed(input_ids)
pos_emb = self.pos_embed(positions)
x = tok_emb + pos_emb
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
if kv_cache is not None:
k = torch.cat([kv_cache["k"], k], dim=1)
v = torch.cat([kv_cache["v"], v], dim=1)
scores = torch.matmul(q, k.transpose(-2, -1)) / (self.dim ** 0.5)
if kv_cache is None:
mask = torch.tril(torch.ones(T, T, device=self.device)).bool()
scores = scores.masked_fill(~mask, float("-inf"))
else:
q_pos = torch.arange(past_len, past_len + T, device=self.device).unsqueeze(1)
k_pos = torch.arange(k.shape[1], device=self.device).unsqueeze(0)
mask = (k_pos <= q_pos).bool()
scores = scores.masked_fill(~mask, float("-inf"))
attn = F.softmax(scores, dim=-1)
hidden = torch.matmul(attn, v)
logits = self.out(hidden)
cache = {"k": k, "v": v}
return logits, cache

model = ToyLLMWithCaching(device=device)
model.eval()
ToyLLMWithCaching(
  (tok_embed): Embedding(100, 32)
  (pos_embed): Embedding(2048, 32)
  (q_proj): Linear(in_features=32, out_features=32, bias=False)
  (k_proj): Linear(in_features=32, out_features=32, bias=False)
  (v_proj): Linear(in_features=32, out_features=32, bias=False)
  (out): Linear(in_features=32, out_features=100, bias=False)
)
import time
@torch.no_grad()
def run_model(model, prompt, max_tokens):
logits, cache = model(prompt) ## Prefill
generated = []
next_token_logits = logits[:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
generated.append(next_token)
for _ in range(max_tokens-1):
logits, cache = model(next_token, cache) ## Decode
next_token_logits = logits[:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
generated.append(next_token)
result = torch.cat([prompt] + generated, dim=1)
return result
max_tokens = 100
prompt = torch.randint(low=0, high=100, size=(10000, 50), device=device)
print(f"Prompt shape: {prompt.shape}")
start = time.perf_counter()
result = run_model(model, prompt, max_tokens)
end = time.perf_counter()
print(f"Result shape: {result.shape}, Time: {end - start:.6f} seconds")
Prompt shape: torch.Size([10000, 50])
Result shape: torch.Size([10000, 150]), Time: 0.334071 seconds
We see a significant speedup in E2E latency, mainly driven by reducing the FLOPs needed for each token generation step. In the above code you can also see that there are two phases in generating tokens: the initial forward pass, which takes the user input and builds the KV cache, and the token generation steps, which use the KV cache to predict the next token. These are the prefill and decode steps respectively. The prefill step is compute bound: its FLOPs are done in parallel across all prompt tokens, so its cost is driven by prompt length. The decode step is memory bound: each step has to stream the growing KV cache from memory, so token generation becomes more expensive as the cache grows.
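To see why the two phases land on opposite sides of compute bound vs memory bound, here is a back-of-envelope arithmetic-intensity estimate for the attention matmuls alone. This is a rough sketch: `attn_intensity` and the fp32, single-head cost formulas are my own simplifications, not profiler output.

```python
# Rough arithmetic intensity (FLOPs per byte of KV cache read) for the
# attention matmuls only, assuming fp32 and a single head:
#   QK^T plus attn @ V cost ~ 4 * q_len * kv_len * dim FLOPs
#   KV cache traffic        ~ 2 * kv_len * dim * 4 bytes
def attn_intensity(q_len: int, kv_len: int, dim: int) -> float:
    flops = 4 * q_len * kv_len * dim
    bytes_moved = 2 * kv_len * dim * 4
    return flops / bytes_moved

dim = 32
print(attn_intensity(q_len=2048, kv_len=2048, dim=dim))  # prefill: 1024.0
print(attn_intensity(q_len=1, kv_len=2048, dim=dim))     # decode:  0.5
```

Under these assumptions, prefill does roughly `q_len / 2` FLOPs per byte of KV it reads, so it can saturate the GPU's compute units, while decode does under one FLOP per byte and spends its time waiting on memory bandwidth.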
Let's run a simple benchmark to profile the prefill and decode steps in our toy LLM.
Note: One thing we are glossing over for the sake of simplicity is that during decoding, different requests can complete generation at different lengths, and we do not want to keep running decode steps for requests in the batch that have already finished. This is optimized through continuous batching, which we will probably get to in another post.
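To give a flavor of the idea before the benchmark, here is a toy sketch of continuous batching. The `continuous_batching` scheduler, the `step` callback, and the `EOS` id are hypothetical stand-ins for illustration, not how a real engine like vLLM structures it: the point is only that a finished sequence frees its batch slot immediately.

```python
EOS = 0  # assumed end-of-sequence token id for this sketch

def continuous_batching(waiting: list[list[int]], step, max_batch: int = 4):
    """step(seq) -> next token id for one sequence (stand-in for a decode pass)."""
    running: list[list[int]] = []
    finished: list[list[int]] = []
    while waiting or running:
        # Admit waiting requests into any free batch slots
        while waiting and len(running) < max_batch:
            running.append(waiting.pop(0))
        # One decode step for every running sequence
        for seq in running:
            seq.append(step(seq))
        # Retire finished sequences, freeing their slots right away
        finished += [s for s in running if s[-1] == EOS]
        running = [s for s in running if s[-1] != EOS]
    return finished

def step(seq):  # stand-in decode: emit EOS once a sequence reaches length 3
    return EOS if len(seq) >= 3 else 1

done = continuous_batching([[5], [6], [7]], step, max_batch=2)
print([s[-1] for s in done])  # every sequence ends with EOS
```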
import time
import torch
from torch.profiler import profile, ProfilerActivity
device = "cuda"
model = ToyLLMWithCaching(device=device).eval()
batch_size = 50000
prompt_len = 50
max_tokens = 100
prompt = torch.randint(low=0, high=100, size=(batch_size, prompt_len), device=device)
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True,
profile_memory=True,
with_stack=False,
) as prof_prefill:
with torch.no_grad():
start = time.perf_counter()
logits, cache = model(prompt)
torch.cuda.synchronize()
end = time.perf_counter()
prefill_time = end - start
print("\n=== PREFILL ===")
print(f"time: {prefill_time:.6f} s")
print(f"max CUDA memory allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
print(
prof_prefill.key_averages().table(
sort_by="cuda_time_total",
row_limit=10
)
)
next_token_logits = logits[:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True,
profile_memory=True,
with_stack=False,
) as prof_decode:
with torch.no_grad():
start = time.perf_counter()
for _ in range(max_tokens - 1):
logits, cache = model(next_token, cache)
next_token_logits = logits[:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
torch.cuda.synchronize()
end = time.perf_counter()
decode_time = end - start
print("\n=== DECODE ===")
print(f"time: {decode_time:.6f} s")
print(f"max CUDA memory allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
print(
prof_decode.key_averages().table(
sort_by="cuda_time_total",
row_limit=10
)
)
=== PREFILL ===
time: 0.076620 s
max CUDA memory allocated: 5629.88 MB
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg CPU Mem Self CPU Mem CUDA Mem Self CUDA Mem # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::matmul 0.18% 133.959us 4.89% 3.717ms 619.523us 0.000us 0.00% 29.532ms 4.922ms 0 B 0 B 2.59 GB 0 B 6
aten::bmm 0.16% 119.914us 1.41% 1.074ms 537.044us 18.164ms 36.57% 18.164ms 9.082ms 0 B 0 B 782.84 MB 782.84 MB 2
aten::linear 0.04% 28.435us 3.44% 2.614ms 653.479us 0.000us 0.00% 11.368ms 2.842ms 0 B 0 B 1.83 GB 0 B 4
aten::mm 0.64% 483.104us 3.14% 2.386ms 596.467us 11.368ms 22.89% 11.368ms 2.842ms 0 B 0 B 1.83 GB 1.83 GB 4
volta_sgemm_128x128_tn 0.00% 0.000us 0.00% 0.000us 0.000us 10.637ms 21.42% 10.637ms 3.546ms 0 B 0 B 0 B 0 B 3
volta_sgemm_64x64_nn 0.00% 0.000us 0.00% 0.000us 0.000us 10.231ms 20.60% 10.231ms 10.231ms 0 B 0 B 0 B 0 B 1
aten::masked_fill 0.02% 17.215us 0.17% 127.410us 127.410us 0.000us 0.00% 9.348ms 9.348ms 0 B 0 B 476.84 MB 0 B 1
volta_sgemm_64x64_tn 0.00% 0.000us 0.00% 0.000us 0.000us 7.933ms 15.97% 7.933ms 7.933ms 0 B 0 B 0 B 0 B 1
aten::softmax 0.01% 6.145us 0.06% 45.025us 45.025us 0.000us 0.00% 6.058ms 6.058ms 0 B 0 B 476.84 MB 0 B 1
aten::_softmax 0.03% 19.782us 0.05% 38.880us 38.880us 6.058ms 12.20% 6.058ms 6.058ms 0 B 0 B 476.84 MB 476.84 MB 1
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 75.941ms
Self CUDA time total: 49.666ms
=== DECODE ===
time: 2.109107 s
max CUDA memory allocated: 3794.50 MB
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg CPU Mem Self CPU Mem CUDA Mem Self CUDA Mem # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::cat 0.23% 4.837ms 85.85% 1.781s 8.994ms 1.240s 63.33% 1.259s 6.360ms 0 B 0 B 118.03 GB 118.03 GB 198
void at::native::(anonymous namespace)::CatArrayBatc... 0.00% 0.000us 0.00% 0.000us 0.000us 1.240s 63.33% 1.240s 6.326ms 0 B 0 B 0 B 0 B 196
aten::matmul 0.35% 7.303ms 2.90% 60.241ms 101.415us 0.000us 0.00% 625.961ms 1.054ms 0 B 0 B 6.15 GB 0 B 594
aten::bmm 0.37% 7.739ms 1.00% 20.801ms 105.057us 569.914ms 29.11% 587.082ms 2.965ms 0 B 0 B 2.46 GB 2.46 GB 198
std::enable_if<!(false), void>::type internal::gemvx... 0.00% 0.000us 0.00% 0.000us 0.000us 292.407ms 14.93% 292.407ms 2.984ms 0 B 0 B 0 B 0 B 98
std::enable_if<!(false), void>::type internal::gemvx... 0.00% 0.000us 0.00% 0.000us 0.000us 277.507ms 14.17% 277.507ms 2.803ms 0 B 0 B 0 B 0 B 99
aten::linear 0.08% 1.603ms 1.76% 36.437ms 92.013us 0.000us 0.00% 38.879ms 98.179us 0 B 0 B 3.69 GB 0 B 396
aten::mm 0.63% 13.160ms 1.17% 24.342ms 61.469us 38.209ms 1.95% 38.879ms 98.179us 0 B 0 B 3.69 GB 3.69 GB 396
cudaLaunchKernel 1.13% 23.448ms 2.59% 53.717ms 27.130us 0.000us 0.00% 38.241ms 19.314us 0 B 0 B 0 B 0 B 1980
Command Buffer Full 1.46% 30.269ms 1.46% 30.269ms 1.593ms 38.241ms 1.95% 38.241ms 2.013ms 0 B 0 B 0 B 0 B 19
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 2.074s
Self CUDA time total: 1.958s

Say we have an 8-GPU node and we are looking to serve our LLM from this node. If we alternate between prefill and decode runs on the entire node, we are going to hurt our TTFT and ITL metrics, because in-flight requests will hold up the token generation process, and our users won't be happy. Our boss won't be happy either, because this leads to under-utilization of the GPUs.
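As a quick illustration of those two metrics (the timestamps below are made up): TTFT is the gap from request arrival to the first token, and ITL is the gap between consecutive tokens after that.

```python
arrival = 0.0
token_times = [0.40, 0.45, 0.51, 0.56, 0.62]  # seconds at which tokens arrived

ttft = token_times[0] - arrival                          # time to first token
itls = [b - a for a, b in zip(token_times, token_times[1:])]
avg_itl = sum(itls) / len(itls)                          # inter-token latency
print(f"TTFT: {ttft:.2f}s, avg ITL: {avg_itl * 1000:.0f}ms")
```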
There are two key changes we can implement here to improve this: 1) Prefill/decode (PD) disaggregation 2) Chunked prefill
PD disaggregation ensures that you can maintain a predictable ITL, and it lets you split the two workloads and efficiently utilize your GPUs. However, there is one gotcha: when you disaggregate prefill and decode onto different GPUs, you need to share the KV cache between them. This transfer can be sped up through RDMA, but it is costly. There is also a tradeoff: TTFT increases, because you now have to account for the RDMA transfer as well before you can start decoding and return the first token.
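A minimal sketch of what that handoff involves, using the same {"k", "v"} cache layout as our toy model. Here `torch.save` into an in-memory buffer stands in for the wire; a real system would stream the tensors over RDMA rather than serialize them like this.

```python
import io
import torch

# Prefill worker's output for a batch of 4 requests, 50 tokens each, dim=32
batch, seq_len, dim = 4, 50, 32
cache = {
    "k": torch.randn(batch, seq_len, dim),
    "v": torch.randn(batch, seq_len, dim),
}

# Bytes that must cross the wire before decode can emit its first token (fp32)
wire_bytes = sum(t.numel() * t.element_size() for t in cache.values())
print(f"KV cache to transfer: {wire_bytes / 1024:.1f} KiB")  # 50.0 KiB

buf = io.BytesIO()
torch.save(cache, buf)        # prefill side: serialize
buf.seek(0)
received = torch.load(buf)    # decode side: deserialize and resume decoding
assert torch.equal(received["k"], cache["k"])
```

Even for this toy model the cache is tens of KiB per short batch; for a real multi-layer, multi-head model at long context it runs to gigabytes, which is why the transfer path matters so much.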
Chunked prefill ensures that a single long request does not monopolize the compute. It splits the prompt tokens into chunks and prefills one chunk at a time. This translates into multiple forward passes through the model to build the KV cache for all tokens, and the scheduler interleaves these forward passes with other batches so that requests of all lengths can make progress.
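Here is a sketch of that chunking loop, written against the same `(input_ids, kv_cache) -> (logits, cache)` contract as ToyLLMWithCaching above. The `model` below is a hypothetical stand-in that only grows the cache, so the snippet stays self-contained; with the real class, its decode-branch causal mask already handles multi-token chunks correctly.

```python
import torch

def model(input_ids, kv_cache=None):
    # Stand-in for ToyLLMWithCaching: fake K/V from the inputs, dummy logits
    k = v = input_ids.float().unsqueeze(-1).expand(-1, -1, 32)
    if kv_cache is not None:
        k = torch.cat([kv_cache["k"], k], dim=1)
        v = torch.cat([kv_cache["v"], v], dim=1)
    logits = torch.zeros(input_ids.shape[0], input_ids.shape[1], 100)
    return logits, {"k": k, "v": v}

def chunked_prefill(model, prompt, chunk_size: int = 16):
    cache = None
    for start in range(0, prompt.shape[1], chunk_size):
        chunk = prompt[:, start:start + chunk_size]
        logits, cache = model(chunk, cache)  # other batches can run in between
    return logits, cache                     # cache now covers the full prompt

prompt = torch.randint(0, 100, (2, 50))
logits, cache = chunked_prefill(model, prompt)
print(cache["k"].shape)  # torch.Size([2, 50, 32])
```

The key property: after the loop, the KV cache is identical to what a single monolithic prefill would have produced, but the scheduler got a chance to run decode steps for other requests between chunks.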
Okay, we can do this within a single node. What about across nodes, and how do you even do this across data centers?