vidur_mlsys24.pdf
LLM Inference
- Prefill Phase: processes the entire user input prompt and initializes/updates the KV Cache → logits for the next token become available
- Token Sampling: a sampling procedure uses the logits to select the next new token
- Decode Phase: each Transformer block contains an Attention layer + an MLP layer
- Attention: the newly generated token is used to compute its K and V vectors, which are appended to the KV Cache
- MLP: applies a position-wise feed-forward transformation to the token's attention output
- Tokens are generated one at a time until an EOS (end-of-sequence) token is produced (see the sketch after this list)
- Requires access to key and value activations of previously processed tokens to perform attention
- Stored in KV Cache
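A minimal sketch of this prefill → sample → decode loop, written against a toy stand-in for the model's forward pass (the function names, vocabulary size, and fake logits/KV entries are illustrative assumptions, not from the paper):

```python
import numpy as np

EOS_TOKEN = 0
VOCAB_SIZE = 32_000
rng = np.random.default_rng(0)

def toy_forward(token_ids, kv_cache):
    """Stand-in for one Transformer forward pass: appends (fake) K/V entries
    for the new tokens and returns logits for the last position."""
    for _ in token_ids:
        kv_cache.append(rng.standard_normal(8))   # fake per-token K/V entry
    return rng.standard_normal(VOCAB_SIZE)        # fake logits

def sample(logits, temperature=1.0):
    """Token sampling: turn logits into a distribution and draw the next token."""
    scaled = logits / temperature
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()
    return int(rng.choice(VOCAB_SIZE, p=probs))

def generate(prompt_ids, max_new_tokens=16):
    kv_cache = []                                    # one entry per processed token
    logits = toy_forward(prompt_ids, kv_cache)       # prefill: whole prompt at once
    output = []
    for _ in range(max_new_tokens):
        next_token = sample(logits)                  # token sampling
        if next_token == EOS_TOKEN:                  # stop at end-of-sequence
            break
        output.append(next_token)
        logits = toy_forward([next_token], kv_cache) # decode: one token per step
    return output

print(generate([101, 2023, 2003], max_new_tokens=5))
```

The key structural point is that prefill processes the whole prompt in one call, while each decode step processes exactly one new token and grows the KV Cache by one entry.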
Additional notes:
- There can be multiple Transformer blocks
- each block is responsible for extracting different context for the same token
- each block has its own KV Cache
- each block produces its own token representation (token_x) that serves as input to the next Transformer block
- There can be multiple attention heads (run in parallel) within a single Transformer block
- each head is responsible for focusing on different parts of the sequence for a given token
- each head outputs a vector representing the token
- these per-head outputs are concatenated and then projected to the dimensions expected by the MLP layer
- For a Transformer block with input X, the Q, K, and V matrices are generated by learned projection matrices (W_q, W_k, W_v) and then split equally across the attention heads in that layer (see the sketch after this list)
- Q is used for the current attention computation and then discarded
- K and V are stored in the KV Cache
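A minimal NumPy sketch of the projection-and-split step above; the model dimension, head count, and random weight matrices are illustrative assumptions:

```python
import numpy as np

d_model, n_heads = 512, 8
d_head = d_model // n_heads
rng = np.random.default_rng(0)

# Learned projection matrices (random stand-ins here).
W_q = rng.standard_normal((d_model, d_model))
W_k = rng.standard_normal((d_model, d_model))
W_v = rng.standard_normal((d_model, d_model))

def project_and_split(X, kv_cache):
    """X: (num_new_tokens, d_model) activations entering one Transformer block."""
    Q = (X @ W_q).reshape(-1, n_heads, d_head)   # split equally across heads
    K = (X @ W_k).reshape(-1, n_heads, d_head)
    V = (X @ W_v).reshape(-1, n_heads, d_head)
    # K and V are appended to this block's KV Cache; Q is used once and discarded.
    kv_cache["K"].append(K)
    kv_cache["V"].append(V)
    return Q

cache = {"K": [], "V": []}          # one cache per Transformer block
Q = project_and_split(rng.standard_normal((3, d_model)), cache)
print(Q.shape, len(cache["K"]))     # (3, 8, 64) 1
```

Only K and V enter the per-block cache; Q exists just long enough to score the current tokens against the cached keys.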
Optimization Techniques:
- Tensor parallelism (TP) / pipeline parallelism (PP)
- Prefill / decode priority scheduling (a rough sketch follows this list)
- Note: the best optimization technique depends on the model and the workload
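As a rough illustration of what prefill/decode priority scheduling means (this is a generic greedy batch builder, not Vidur's or any specific system's scheduler; the token budget and field names are assumptions):

```python
from dataclasses import dataclass

@dataclass
class Request:
    id: int
    phase: str              # "prefill" or "decode"
    pending_tokens: int     # prompt tokens left to prefill, or 1 for a decode step

def build_batch(queue, token_budget=2048, prioritize="prefill"):
    """Greedy batch builder: take requests from the prioritized phase first,
    then fill any remaining budget with requests from the other phase."""
    ordered = sorted(queue, key=lambda r: r.phase != prioritize)
    batch, used = [], 0
    for req in ordered:
        cost = req.pending_tokens if req.phase == "prefill" else 1
        if used + cost <= token_budget:
            batch.append(req)
            used += cost
    return batch

queue = [Request(0, "decode", 1), Request(1, "prefill", 1500), Request(2, "prefill", 800)]
print([r.id for r in build_batch(queue, prioritize="prefill")])   # [1, 0] under this budget
```

Flipping `prioritize` between the two phases trades prompt-processing throughput against decode latency, which is exactly why the right choice depends on the workload.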
Challenges of LLM Inference Simulation
- must be accurate at much finer time granularities than training simulation
- An iteration of LLM training (forward pass, loss calculation, backward pass) takes hundreds of milliseconds, while an LLM inference iteration takes a few milliseconds
- LLM inference consists of a prefill stage (a one-time cost) and a decode stage that performs far less computation per iteration
- Iteration time variation
- Different characteristics of computation (prefill: compute-bound; decode: memory-bound [KV Cache])
- Varying request lengths (prompt lengths / decode tokens generated)
- system load or workload characteristics change
- interleaving of prefill and decode stages depends on the scheduling strategy → leads to changes in iteration latency
- Inference is stateful, so small errors can cascade
- requests arrive in the system dynamically, and if the runtime prediction for any batch is off, it can change the batching pattern of subsequent iterations (see the toy example after this list)
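A toy example of this cascading effect, under heavily simplified assumptions (fixed batch capacity, requests only join batches that start after they arrive; all numbers are made up):

```python
def simulate(batch_runtimes_ms, arrivals_ms, batch_size=2):
    """Very simplified: a request joins the next batch that starts after it arrives."""
    clock, waiting, served = 0.0, list(arrivals_ms), []
    for runtime in batch_runtimes_ms:
        batch = [t for t in waiting if t <= clock][:batch_size]
        waiting = [t for t in waiting if t not in batch]
        served.append(batch)
        clock += runtime
    return served

arrivals = [0.0, 0.0, 9.5]                 # third request arrives at 9.5 ms
print(simulate([10.0, 10.0], arrivals))    # true runtime: request 3 joins batch 2
print(simulate([9.0, 10.0], arrivals))     # 1 ms underprediction: request 3 misses batch 2
```

A 1 ms error in one batch's predicted runtime changes which requests the next batch contains, and from that point on the simulated and real schedules diverge.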
VIDUR Design
- LLMs share key architectural properties → it is enough to model a small number of compute operators that are shared across models
- For a given model and a given running batch, each request is associated with a different number of KV-Cache tokens and query tokens
- Two categories of LLM operators (see the sketch after this list):
- execution time depends on the total context length of all requests in the batch
- execution time depends on the total number of tokens in the current iteration
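A hedged sketch of what these two categories key on. Vidur derives operator runtimes from profiling data rather than closed-form formulas, so the linear cost functions and per-token constants below are placeholders that only show which batch quantity drives each category:

```python
def attention_decode_time(batch, us_per_context_token=0.05):
    """Category 1: runtime grows with the total KV-Cache context length in the batch."""
    total_context = sum(req["kv_cache_tokens"] for req in batch)
    return total_context * us_per_context_token

def linear_layer_time(batch, us_per_token=0.8):
    """Category 2 (MLP / projections): runtime grows with the number of tokens
    processed in the current iteration (prompt chunks + one token per decode)."""
    total_tokens = sum(req["query_tokens"] for req in batch)
    return total_tokens * us_per_token

batch = [
    {"kv_cache_tokens": 1024, "query_tokens": 1},    # decode request
    {"kv_cache_tokens": 0,    "query_tokens": 512},  # prefill request
]
print(attention_decode_time(batch), linear_layer_time(batch))
```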
Decode Step - Attention + MLP
- Attention Kernel - dependent on request history (KV Cache)
- The current token being processed (the query token) is compared against all previous tokens (their key vectors) to calculate attention scores
- Attention score: how relevant each previous token is to the current one
- Attention score is used to create a weighted average of the Value vectors for all previous tokens
- Output is a fixed-size vector that represents the token's meaning in the given context
- The work required is proportional to the length of the KV Cache (the request history); see the sketch below
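A minimal single-head NumPy sketch of this decode-time attention against the KV Cache (the scaling by sqrt(d_head) and the softmax are standard; the dimensions are illustrative):

```python
import numpy as np

def decode_attention(q, K_cache, V_cache):
    """q: (d_head,) query vector for the current token.
    K_cache, V_cache: (context_len, d_head) for all previously processed tokens.
    Cost is O(context_len): one score and one weighted V row per cached token."""
    scores = K_cache @ q / np.sqrt(q.shape[0])   # relevance of each previous token
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()                     # softmax over the request history
    return weights @ V_cache                     # fixed-size output vector

rng = np.random.default_rng(0)
context_len, d_head = 1000, 64
out = decode_attention(rng.standard_normal(d_head),
                       rng.standard_normal((context_len, d_head)),
                       rng.standard_normal((context_len, d_head)))
print(out.shape)    # (64,)
```

Both the score computation and the weighted sum touch every cached row once, which is where the linear dependence on context length comes from.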