What happens when you type a message and hit enter in the ChatGPT app? What happens when you call the Claude API?
Let's dive in and look at how LLM inference is performed at large scale. This post assumes some background in transformers, self-attention, and PyTorch.
Let's start with a toy LLM that has just a single attention block. This helps us focus on the core levers to look for when optimizing inference for these models.
import torch
import torch.nn as nn
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class ToyLLM(nn.Module):
    def __init__(
        self,
        vocab_size: int = 100,
        dim: int = 32,
        max_seq_len: int = 2048,
        device: str = "cpu"
    ):
        super().__init__()
        self.tok_embed = nn.Embedding(vocab_size, dim, device=device)
        self.pos_embed = nn.Embedding(max_seq_len, dim, device=device)
        ## Attention block
        self.q_proj = nn.Linear(dim, dim, bias=False, device=device)
        self.k_proj = nn.Linear(dim, dim, bias=False, device=device)
        self.v_proj = nn.Linear(dim, dim, bias=False, device=device)
        self.out = nn.Linear(dim, vocab_size, bias=False, device=device)
        self.dim = dim
        self.device = device

    def forward(self, input_ids: torch.Tensor):
        B, T = input_ids.shape
        positions = torch.arange(T, device=self.device).unsqueeze(0)
        tok_emb = self.tok_embed(input_ids)
        pos_emb = self.pos_embed(positions)
        x = tok_emb + pos_emb
        ## Calculate self-attention
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.dim ** 0.5)
        ## Causal mask: each position attends only to itself and earlier positions
        mask = torch.tril(torch.ones(T, T, device=self.device)).bool()
        scores = scores.masked_fill(~mask, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        hidden = torch.matmul(attn, v)
        logits = self.out(hidden)
        return logits

Let's take our toy LLM and pass a dummy tokenized prompt through it. The prompt stands in for a user input, i.e. an instruction such as "What is the weather today?" or "Generate code". A forward pass through the model returns a tensor of logits of shape (batch size, sequence length, vocab size): each position predicts the next token for the prefix ending at that position. In other words, the model is constantly trying to predict what comes next.
model = ToyLLM(device=device)
model.eval()
prompt = torch.tensor([[1, 2, 3, 4]], device=device)
logits = model(prompt)
print(logits.shape)

torch.Size([1, 4, 100])
In the above example, we have already provided the first four tokens, so we are interested in the 5th token in the sequence, which is the first token generated by the LLM. We get it by taking the argmax over the logits at the last prompt position. Taking the argmax greedily picks the next token with the highest probability.
next_token_logits = logits[:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
prompt = torch.cat([prompt, next_token], dim=1)
print(prompt)

tensor([[ 1, 2, 3, 4, 14]], device='cuda:0')
Now that we have our 5th token, we can append it to the original input tokens and do another forward pass through the model to get the 6th token. We can keep iterating this process to generate more tokens until we hit our max_tokens generation limit or the model generates a special [EOS] token indicating that the generation is complete. This is how LLMs generate tokens in an auto-regressive fashion.
logits = model(prompt)
next_token_logits = logits[:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
prompt = torch.cat([prompt, next_token], dim=1)
print(prompt)

tensor([[ 1, 2, 3, 4, 14, 14]], device='cuda:0')
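The repeat-until-done loop described above, including the stop conditions, can be sketched without PyTorch. This is only an illustration: `toy_logits` is a hypothetical stand-in for a model forward pass and `EOS_ID` is an assumed token id; neither appears in the code above.

```python
from typing import Callable, List

EOS_ID = 0  # assumed end-of-sequence token id (hypothetical)


def toy_logits(tokens: List[int], vocab_size: int = 8) -> List[float]:
    """Hypothetical stand-in for a forward pass: returns next-token logits.

    It favours (last_token + 1) % vocab_size, so sequences count upward
    and eventually wrap around to EOS_ID.
    """
    target = (tokens[-1] + 1) % vocab_size
    return [1.0 if i == target else 0.0 for i in range(vocab_size)]


def greedy_generate(
    logits_fn: Callable[[List[int]], List[float]],
    prompt: List[int],
    max_tokens: int,
    eos_id: int = EOS_ID,
) -> List[int]:
    """Append the argmax token until max_tokens is reached or EOS is produced."""
    tokens = list(prompt)
    for _ in range(max_tokens):
        logits = logits_fn(tokens)  # forward pass over the whole prefix
        next_token = max(range(len(logits)), key=logits.__getitem__)  # greedy argmax
        tokens.append(next_token)
        if next_token == eos_id:  # the model signalled that generation is complete
            break
    return tokens


print(greedy_generate(toy_logits, [1, 2, 3, 4], max_tokens=100))  # stops at EOS
print(greedy_generate(toy_logits, [1, 2, 3, 4], max_tokens=2))    # stops at the token limit
```

The first call ends early because the toy model emits `EOS_ID`; the second ends because it hits `max_tokens`.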
Notice that in the above code, after generating a new token and appending it to the input, we recompute the K and V values for every earlier token in the attention block, even though they were already computed in the previous forward pass. Recomputing them on every forward pass instead of reusing them wastes FLOPs. Let's use a simple timer to see how long it takes to generate 100 tokens for a batch size of 10K (reported numbers are from a T4 GPU).
import time

max_tokens = 100
prompt = torch.randint(low=0, high=100, size=(10000, 50), device=device)
print(f"Prompt shape: {prompt.shape}")

with torch.no_grad():
    ## CUDA kernels run asynchronously, so synchronize before reading the clock
    if device.type == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(max_tokens):
        logits = model(prompt)
        next_token_logits = logits[:, -1, :]
        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        prompt = torch.cat([prompt, next_token], dim=1)
    if device.type == "cuda":
        torch.cuda.synchronize()
    end = time.perf_counter()
print(f"Result shape: {prompt.shape}, Time: {end - start:.6f} seconds")

Prompt shape: torch.Size([10000, 50])
Result shape: torch.Size([10000, 150]), Time: 3.579580 seconds
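To put a rough number on the redundant work in the timing loop above, we can count, per sequence, how many token positions go through the K/V projections and how many attention scores are computed, with and without reuse. This is a back-of-the-envelope sketch; the helper names are made up for illustration, and the "with cache" figures assume the prompt is processed once and each later step processes only the newest token.

```python
def kv_projections_without_cache(prompt_len: int, new_tokens: int) -> int:
    """Token positions projected to K/V when every step re-runs the full prefix."""
    # Step i (0-based) processes prompt_len + i positions.
    return sum(prompt_len + i for i in range(new_tokens))


def kv_projections_with_cache(prompt_len: int, new_tokens: int) -> int:
    """Token positions projected when K/V are reused across steps."""
    # The prompt is projected once; the last generated token is never fed back in.
    return prompt_len + (new_tokens - 1)


def attn_scores_without_cache(prompt_len: int, new_tokens: int) -> int:
    """Query-key scores computed when each step builds a full L x L score matrix."""
    return sum((prompt_len + i) ** 2 for i in range(new_tokens))


def attn_scores_with_cache(prompt_len: int, new_tokens: int) -> int:
    """Query-key scores when each decode step scores one new query against all keys."""
    prefill = prompt_len ** 2
    decode = sum(prompt_len + i for i in range(1, new_tokens))
    return prefill + decode


P, N = 50, 100  # prompt length and generated tokens, matching the benchmark above
print(kv_projections_without_cache(P, N), kv_projections_with_cache(P, N))
print(attn_scores_without_cache(P, N), attn_scores_with_cache(P, N))
```

For these sizes, the counts are 9,950 versus 149 projected positions and 1,073,350 versus 12,400 attention scores per sequence.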
Instead of recomputing the K and V values, let's create a new version of the model that reuses the ones computed in previous forward passes. For the new token N, we compute fresh Q/K/V projections and calculate attention for the new query against all cached keys plus its own. For tokens 1..N-1, we reuse the cached K/V tensors instead of recomputing them.
[Figure: Transformer block with a KV cache. As input tokens are processed, each token's K/V entry (KV1, KV2, ...) is appended to the cache.]
class ToyLLMWithCaching(nn.Module):
    def __init__(
        self,
        vocab_size: int = 100,
        dim: int = 32,
        max_seq_len: int = 2048,
        device: str = "cpu"
    ) -> None:
        super().__init__()
        self.tok_embed = nn.Embedding(vocab_size, dim, device=device)
        self.pos_embed = nn.Embedding(max_seq_len, dim, device=device)
        self.q_proj = nn.Linear(dim, dim, bias=False, device=device)
        self.k_proj = nn.Linear(dim, dim, bias=False, device=device)
        self.v_proj = nn.Linear(dim, dim, bias=False, device=device)
        self.out = nn.Linear(dim, vocab_size, bias=False, device=device)
        self.dim = dim
        self.device = device

    def forward(
        self,
        input_ids: torch.Tensor,
        kv_cache: dict[str, torch.Tensor] | None = None
    ):
        B, T = input_ids.shape
        ## Tokens already in the cache; the new tokens take the positions after them
        past_len = 0 if kv_cache is None else kv_cache["k"].shape[1]
        positions = torch.arange(past_len, past_len + T, device=self.device).unsqueeze(0)
        tok_emb = self.tok_embed(input_ids)
        pos_emb = self.pos_embed(positions)
        x = tok_emb + pos_emb
        ## Q/K/V are computed only for the new tokens
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        if kv_cache is not None:
            ## Prepend the cached K/V so the new queries can attend to the full prefix
            k = torch.cat([kv_cache["k"], k], dim=1)
            v = torch.cat([kv_cache["v"], v], dim=1)
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.dim ** 0.5)
        if kv_cache is None:
            mask = torch.tril(torch.ones(T, T, device=self.device)).bool()
            scores = scores.masked_fill(~mask, float("-inf"))
        else:
            ## Causal mask using absolute positions: a query sees keys at or before its own position
            q_pos = torch.arange(past_len, past_len + T, device=self.device).unsqueeze(1)
            k_pos = torch.arange(k.shape[1], device=self.device).unsqueeze(0)
            mask = (k_pos <= q_pos).bool()
            scores = scores.masked_fill(~mask, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        hidden = torch.matmul(attn, v)
        logits = self.out(hidden)
        cache = {"k": k, "v": v}
        return logits, cache

model = ToyLLMWithCaching(device=device)
model.eval()

ToyLLMWithCaching(
  (tok_embed): Embedding(100, 32)
  (pos_embed): Embedding(2048, 32)
  (q_proj): Linear(in_features=32, out_features=32, bias=False)
  (k_proj): Linear(in_features=32, out_features=32, bias=False)
  (v_proj): Linear(in_features=32, out_features=32, bias=False)
  (out): Linear(in_features=32, out_features=100, bias=False)
)
Let's time how long it takes to process the same input size as before when we reuse the K/V values (i.e. use a KV cache).

import time

@torch.no_grad()
def run_model(model, prompt, max_tokens):
    logits, cache = model(prompt)  ## Prefill
    generated = []
    next_token_logits = logits[:, -1, :]
    next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
    generated.append(next_token)
    for _ in range(max_tokens-1):
        logits, cache = model(next_token, cache)  ## Decode
        next_token_logits = logits[:, -1, :]
        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        generated.append(next_token)
    result = torch.cat([prompt] + generated, dim=1)
    return result

max_tokens = 100
prompt = torch.randint(low=0, high=100, size=(10000, 50), device=device)
print(f"Prompt shape: {prompt.shape}")
## CUDA kernels run asynchronously, so synchronize before reading the clock
if device.type == "cuda":
    torch.cuda.synchronize()
start = time.perf_counter()
result = run_model(model, prompt, max_tokens)
if device.type == "cuda":
    torch.cuda.synchronize()
end = time.perf_counter()
print(f"Result shape: {result.shape}, Time: {end - start:.6f} seconds")

Prompt shape: torch.Size([10000, 50])
Result shape: torch.Size([10000, 150]), Time: 0.163375 seconds
We can see a significant speedup in end-to-end (E2E) latency: roughly 22x here (3.58 s down to 0.16 s). This is mainly driven by reducing the number of FLOPs needed for each token generation step. In the above code, generation has two phases:

1. The initial forward pass takes the entire user input, builds the KV cache, and produces the logits for the first output token.
2. The token generation loop then uses the KV cache to predict subsequent tokens one at a time.

These two phases are the prefill and decode steps, respectively. Prefill is usually compute-bound because a large prompt can be processed in parallel with high arithmetic intensity. Decode is usually memory-bandwidth-bound because each step reads the model weights and the growing KV cache to produce just one token per sequence. This distinction is one reason model providers often charge more for output tokens than for input tokens.
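The compute-bound versus memory-bound distinction can be illustrated with a rough roofline-style estimate for a single `dim x dim` linear layer. This is a simplified sketch: it counts only weight and activation traffic (not KV cache reads), and `PEAK_FLOPS` and `MEM_BANDWIDTH` are illustrative placeholder values, not the specs of any particular GPU.

```python
def linear_layer_intensity(tokens: int, dim: int, bytes_per_value: int = 2) -> float:
    """FLOPs per byte moved for y = x @ W, with x: (tokens, dim) and W: (dim, dim)."""
    flops = 2 * tokens * dim * dim                          # one multiply and one add per weight per token
    weight_bytes = dim * dim * bytes_per_value              # W is read once, however many tokens there are
    activation_bytes = 2 * tokens * dim * bytes_per_value   # read x, write y
    return flops / (weight_bytes + activation_bytes)


def classify(intensity: float, peak_flops: float, mem_bandwidth: float) -> str:
    """Compare arithmetic intensity with the hardware's ridge point (FLOPs per byte)."""
    ridge = peak_flops / mem_bandwidth
    return "compute-bound" if intensity >= ridge else "memory-bound"


PEAK_FLOPS = 100e12     # assumed: 100 TFLOP/s (illustrative)
MEM_BANDWIDTH = 1e12    # assumed: 1 TB/s (illustrative)

for label, tokens in (("decode, 1 token", 1), ("prefill, 2048 tokens", 2048)):
    ai = linear_layer_intensity(tokens, dim=8192)
    print(f"{label}: {ai:.1f} FLOPs/byte -> {classify(ai, PEAK_FLOPS, MEM_BANDWIDTH)}")
```

With one token per step, the layer does about one FLOP per byte of weights read, far below the assumed ridge point of 100. A 2048-token prefill reuses each weight across all tokens and lands well above it.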
Note: One thing we are glossing over for simplicity is that different requests finish generating at different lengths once decoding starts. The above code always generates exactly 100 tokens for every input, which would not be the case for real requests. Real serving systems handle this with continuous batching: a sequence leaves the batch as soon as it finishes, and a waiting request takes its slot.
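To see why this matters, here is a toy simulation that counts decode steps for a queue of requests with different output lengths. It is a sketch under simplifying assumptions (every step costs the same, prefill is ignored, all requests are queued up front), and the function names are hypothetical rather than taken from any serving framework.

```python
from collections import deque
from typing import List


def static_batching_steps(output_lens: List[int], batch_size: int) -> int:
    """Decode steps when each batch runs until its longest sequence finishes."""
    steps = 0
    for i in range(0, len(output_lens), batch_size):
        steps += max(output_lens[i:i + batch_size])
    return steps


def continuous_batching_steps(output_lens: List[int], batch_size: int) -> int:
    """Decode steps when a finished sequence's slot is refilled right away."""
    waiting = deque(output_lens)
    running: List[int] = []  # tokens still to generate for each in-flight sequence
    steps = 0
    while waiting or running:
        # Admit waiting requests into any free slots before the next step.
        while waiting and len(running) < batch_size:
            running.append(waiting.popleft())
        steps += 1  # one decode step emits one token per running sequence
        running = [r - 1 for r in running if r > 1]  # drop sequences that just finished
    return steps


lens = [8, 2, 2, 8, 2, 2, 2, 2]  # output lengths in tokens, each at least 1
print(static_batching_steps(lens, batch_size=2))      # 20
print(continuous_batching_steps(lens, batch_size=2))  # 14
```

With static batching, short requests sit idle in their slot until the longest sequence in the batch is done. Refilling slots immediately brings this example down to 14 steps, which is the minimum possible for 28 tokens on 2 slots.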
Now that we have built the KV cache mechanism, we can use it to serve our LLM from a single GPU. For every new request, we can alternate between prefill and decode steps to generate the output tokens and send them back to the user. However, this can hurt our TTFT (Time To First Token), since requests with long prompts can monopolize the GPU and hold up requests with shorter prompts. Take a look at the visualization below.
To improve our overall TTFT metric, we can stop the decoding step when a new request arrives, perform prefill and resume decoding. This keeps TTFT more predictable for the new request, but hurts ITL (Inter-token Latency) for in-flight requests because their decode stream is paused while the prefill runs.
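The tradeoff can be made concrete with a small timeline calculation. This is a deliberately simplified sketch: one in-flight request A is decoding from time 0, one new request B arrives, all times are in integer milliseconds, and the baseline (B simply waits for A to finish) is an assumption chosen for contrast. The function names are hypothetical.

```python
from typing import Tuple


def wait_for_decode(arrival_ms: int, prefill_ms: int, decode_steps_left: int, step_ms: int) -> Tuple[int, int]:
    """Baseline: B's prefill starts only after A has finished decoding."""
    a_done_ms = decode_steps_left * step_ms
    prefill_start = max(arrival_ms, a_done_ms)
    ttft_b = prefill_start + prefill_ms - arrival_ms
    worst_itl_a = step_ms  # A is never interrupted
    return ttft_b, worst_itl_a


def prefill_first(arrival_ms: int, prefill_ms: int, decode_steps_left: int, step_ms: int) -> Tuple[int, int]:
    """B's prefill runs at the first decode-step boundary after it arrives, pausing A."""
    a_done_ms = decode_steps_left * step_ms
    boundary = -(-arrival_ms // step_ms) * step_ms  # round up to the next step boundary
    prefill_start = max(arrival_ms, min(boundary, a_done_ms))
    ttft_b = prefill_start + prefill_ms - arrival_ms
    # If A is still decoding, its next token is delayed by the whole prefill.
    worst_itl_a = step_ms + prefill_ms if prefill_start < a_done_ms else step_ms
    return ttft_b, worst_itl_a


args = dict(arrival_ms=25, prefill_ms=200, decode_steps_left=100, step_ms=20)
print(wait_for_decode(**args))  # (TTFT for B, worst ITL for A)
print(prefill_first(**args))
```

With these illustrative numbers, running the prefill first cuts B's TTFT from 2175 ms to 215 ms, while A's worst inter-token gap grows from 20 ms to 220 ms.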
In order to maintain predictable TTFT and ITL metrics, we can perform chunked prefills, i.e. prefill in chunks of fixed length so that both prefill and decode steps can be interleaved on the GPU. This can cause a slight increase in TTFT, but we can accept this tradeoff since we get more predictable latency across both steps and also better resource utilization.
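A chunked-prefill schedule can be sketched as a per-step token budget. This is a minimal illustration under stated assumptions, not how any particular serving engine implements it: `token_budget` and `schedule_steps` are hypothetical names, every in-flight decode sequence gets one token per step, and whatever budget is left goes to the next slice of the new prompt.

```python
from typing import Dict, List


def schedule_steps(prompt_len: int, num_decoding: int, token_budget: int) -> List[Dict[str, int]]:
    """Plan engine steps until the new prompt is fully prefilled.

    Each step first reserves one token per in-flight decode sequence, then
    spends the remaining budget on the next chunk of the new prompt.
    """
    if token_budget <= num_decoding:
        raise ValueError("token_budget must leave room for prefill tokens")
    steps = []
    remaining = prompt_len
    while remaining > 0:
        chunk = min(remaining, token_budget - num_decoding)
        steps.append({"decode_tokens": num_decoding, "prefill_tokens": chunk})
        remaining -= chunk
    return steps


# Chunked: at most 264 tokens per step, so decode streams keep a steady cadence.
for step in schedule_steps(prompt_len=1000, num_decoding=8, token_budget=264):
    print(step)

# Unchunked for comparison: the budget covers the whole prompt in one large step.
print(schedule_steps(prompt_len=1000, num_decoding=8, token_budget=1008))
```

In the chunked plan, the 1000-token prompt is spread over four steps (256, 256, 256, and 232 prefill tokens), so its first token arrives a few steps later, but no single step is much larger than the others.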
Now, say we have an 8-GPU node or a GPU cluster in a data center. Instead of running the prefill and decode steps on the same GPU, we can disaggregate them and run each on dedicated GPUs. Because the two workloads differ, this gives us the flexibility to scale them independently and, as a result, optimize our resource utilization. Splitting the workloads also reduces unpredictable spikes in TTFT and ITL.
There is a tradeoff here, though. When prefill and decode run on the same GPU, we can reuse the KV cache directly from HBM; if they run on different GPUs, we have to move the KV cache between them. And KV caches can get pretty big. For a Llama 70B-style model with an FP16 KV cache, 80 layers, 8 KV heads, and 128-dimensional heads, the cache takes 2 (K and V) x 80 x 8 x 128 x 2 bytes = 327,680 bytes, or roughly 320 KB, per live token across the full model. Across long contexts and many concurrent sequences, that adds up to tens of GB: around 120K-240K live tokens would be roughly 40-80 GB. At that point, moving the tensors, rather than the FLOPs we can do, can become the bottleneck for generating tokens quickly. NVIDIA's NVLink/NVSwitch can speed up the transfer of KV cache blocks across GPUs within a node, and we can use RDMA to transfer cache blocks across nodes.
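The sizing arithmetic above is easy to reproduce. The sketch below uses the model shape quoted in the text; `ASSUMED_LINK_GB_PER_S` is an illustrative placeholder for interconnect bandwidth, not a measured or vendor-quoted figure.

```python
def kv_bytes_per_token(num_layers: int, num_kv_heads: int, head_dim: int, bytes_per_value: int = 2) -> int:
    """Bytes of KV cache stored per token across all layers."""
    # The factor of 2 covers one K vector and one V vector per layer.
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_value


def kv_cache_gb(live_tokens: int, bytes_per_token: int) -> float:
    """Total KV cache size in decimal gigabytes."""
    return live_tokens * bytes_per_token / 1e9


def transfer_seconds(num_bytes: int, bandwidth_gb_per_s: float) -> float:
    """Idealized transfer time, ignoring latency and protocol overhead."""
    return num_bytes / (bandwidth_gb_per_s * 1e9)


ASSUMED_LINK_GB_PER_S = 100.0  # hypothetical interconnect bandwidth

per_token = kv_bytes_per_token(num_layers=80, num_kv_heads=8, head_dim=128)
print(f"{per_token} bytes per token ({per_token / 1024:.0f} KiB)")
for live_tokens in (120_000, 240_000):
    size_gb = kv_cache_gb(live_tokens, per_token)
    secs = transfer_seconds(live_tokens * per_token, ASSUMED_LINK_GB_PER_S)
    print(f"{live_tokens} live tokens: {size_gb:.1f} GB, {secs:.2f} s at {ASSUMED_LINK_GB_PER_S:.0f} GB/s")
```

This gives 327,680 bytes per token, and about 39.3 GB and 78.6 GB for 120K and 240K live tokens respectively.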
These optimizations allow us to serve our LLMs with higher throughput, more predictable latency, and better GPU utilization.