Efficient Attention at Scale: KV Cache, GQA & FlashAttention
Canonical Papers
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
- GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
- SnapKV: LLM Knows What You Are Looking for Before Generation
Core Mathematics
Modern LLMs are defined not just by *what* attention is, but by how attention is made feasible at long context with low latency.
Autoregressive attention with KV cache:
At decode step $t$, you form the new query $q_t$ and attend against the cached past keys/values:

$$\mathrm{Attn}(q_t) = \mathrm{softmax}\!\left(\frac{q_t K_{1:t}^\top}{\sqrt{d_k}}\right) V_{1:t}$$

where $K_{1:t}$ and $V_{1:t}$ are stored, not recomputed. The KV cache is essential: without it, every step would re-project keys and values for the entire prefix, making decoding catastrophically slow.
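A minimal sketch of one decode step with a KV cache, for a single attention head in PyTorch. The function name, shapes, and cache layout are illustrative assumptions, not any particular library's API:

```python
import torch
import torch.nn.functional as F

def decode_step(q_t, k_t, v_t, k_cache, v_cache):
    """One autoregressive decode step for a single attention head.

    q_t, k_t, v_t: (batch, d_head) projections of the new token.
    k_cache, v_cache: (batch, t-1, d_head) keys/values of all past tokens.
    Returns the attention output and the updated caches.
    """
    # Append the new key/value instead of recomputing the past ones.
    k_cache = torch.cat([k_cache, k_t.unsqueeze(1)], dim=1)  # (batch, t, d_head)
    v_cache = torch.cat([v_cache, v_t.unsqueeze(1)], dim=1)  # (batch, t, d_head)

    d_head = q_t.shape[-1]
    # Scores of the single new query against every cached key: (batch, t)
    scores = torch.einsum("bd,btd->bt", q_t, k_cache) / d_head ** 0.5
    weights = F.softmax(scores, dim=-1)
    out = torch.einsum("bt,btd->bd", weights, v_cache)       # (batch, d_head)
    return out, k_cache, v_cache
```

With the cache, each step does O(t) attention work against stored tensors; without it, the key and value projections for the whole prefix would be recomputed at every step.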
Grouped-Query Attention (GQA):
Multiple query heads share fewer KV heads. Let query head $i$ map to KV-head group $g(i)$:

$$\mathrm{head}_i = \mathrm{softmax}\!\left(\frac{Q_i K_{g(i)}^\top}{\sqrt{d_k}}\right) V_{g(i)}$$

This interpolates between Multi-Head Attention ($H_{kv} = H_q$, no sharing) and Multi-Query Attention ($H_{kv} = 1$, maximal sharing).
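A sketch of the grouping, assuming $H_q$ is a multiple of $H_{kv}$ and that consecutive query heads share a KV head (the common convention); names and shapes are illustrative:

```python
import torch
import torch.nn.functional as F

def gqa_attention(q, k, v):
    """Grouped-query attention for one layer (no causal mask, for clarity).

    q: (batch, H_q, T, d_head); k, v: (batch, H_kv, T, d_head), with H_q % H_kv == 0.
    """
    group = q.shape[1] // k.shape[1]
    # Query head i reads KV head g(i) = i // group, so each KV head serves `group` query heads.
    # repeat_interleave copies here for clarity; fused kernels index the shared heads instead.
    k = k.repeat_interleave(group, dim=1)  # (batch, H_q, T, d_head)
    v = v.repeat_interleave(group, dim=1)

    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5  # (batch, H_q, T, T)
    return F.softmax(scores, dim=-1) @ v                   # (batch, H_q, T, d_head)
```

Note that the cache only ever holds the $H_{kv}$ key/value heads; the expansion to $H_q$ heads happens at compute time, which is why the saving is in memory and bandwidth rather than FLOPs.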
KV cache memory scaling:
Per layer, the KV cache grows as

$$\text{KV bytes per layer} = 2 \cdot B \cdot T \cdot H_{kv} \cdot d_{head} \cdot \text{(bytes per element)},$$

where $B$ is batch size, $T$ is context length, $H_{kv}$ is the number of KV heads, $d_{head}$ is the per-head dimension, and the factor 2 counts keys and values. GQA reduces KV memory by a factor of $H_{kv}/H_q$ relative to full MHA.
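To make the scaling concrete, a small sketch that instantiates the formula; the 32-layer, 8-KV-head, 128-dim configuration is assumed to be Llama-3-8B-like and is not taken from the cited papers:

```python
def kv_cache_bytes(batch, seq_len, n_layers, n_kv_heads, d_head, bytes_per_elem=2):
    """Total KV cache size: 2 (K and V) * B * T * H_kv * d_head per layer, summed over layers."""
    return 2 * batch * seq_len * n_layers * n_kv_heads * d_head * bytes_per_elem

# Assumed Llama-3-8B-like config: 32 layers, 8 KV heads (GQA), head dim 128, fp16.
gqa = kv_cache_bytes(batch=1, seq_len=8192, n_layers=32, n_kv_heads=8, d_head=128)
mha = kv_cache_bytes(batch=1, seq_len=8192, n_layers=32, n_kv_heads=32, d_head=128)
print(f"GQA: {gqa / 2**30:.2f} GiB, MHA: {mha / 2**30:.2f} GiB, ratio = {gqa / mha:.2f}")
# Roughly 1 GiB vs 4 GiB at 8k context: the H_kv/H_q = 8/32 = 0.25 reduction.
```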
FlashAttention's key insight: Reorder computation to minimize memory movement. Instead of materializing the full attention matrix in slow GPU memory (HBM), stream the softmax computation through fast on-chip memory (SRAM) via tiling. Same exact attention math, radically different memory behavior.
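The sketch below shows the online-softmax recurrence that this tiling relies on, for a single query vector in PyTorch: a running max and normalizer let each key/value tile be folded in without ever forming the full row of scores. Tile size and names are illustrative; the real FlashAttention kernel additionally tiles over queries, manages SRAM explicitly, and fuses the backward pass.

```python
import torch

def tiled_attention(q, K, V, tile=128):
    """Exact attention for one query q (d,) against K, V (T, d), streamed over key/value tiles."""
    d = q.shape[-1]
    out = torch.zeros_like(q)          # running weighted sum of values
    m = torch.tensor(float("-inf"))    # running max of scores seen so far
    denom = torch.tensor(0.0)          # running softmax denominator

    for start in range(0, K.shape[0], tile):
        k_blk, v_blk = K[start:start + tile], V[start:start + tile]
        s = k_blk @ q / d ** 0.5                 # (tile,) scores; the full (T, T) matrix never exists
        m_new = torch.maximum(m, s.max())
        corr = torch.exp(m - m_new)              # rescale previous partial sums to the new max
        p = torch.exp(s - m_new)
        denom = denom * corr + p.sum()
        out = out * corr + p @ v_blk
        m = m_new

    return out / denom
```

As a sanity check, the result matches the naive `torch.softmax(K @ q / d ** 0.5, dim=-1) @ V` up to floating-point error; only the memory access pattern differs.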
Why It Matters for Modern Models
- Llama 3 explicitly uses GQA for inference efficiency—this is the production default for modern open LLMs
- Long context makes KV cache the dominant inference cost: memory grows linearly with T, bandwidth becomes the bottleneck
- FlashAttention enabled much longer contexts and faster training by minimizing memory traffic without approximating attention
- NeurIPS 2024 / ICLR 2025 focus heavily on KV cache compression (SnapKV, RazorAttention)—this is the active research frontier
- Inference cost now rivals training cost at scale, making memory-efficient attention essential for deployment
Missing Intuition
What is still poorly explained in textbooks and papers:
- Attention is "quadratic" on paper, but the real enemy is memory movement—FlashAttention reorders computation to keep data on-chip
- KV cache is the dominant inference memory budget at long context; it can exceed the model weights in footprint (see the back-of-the-envelope sketch after this list)
- GQA is about bandwidth + cache size, not "making attention cheaper"—you still compute attention, but read/store less
- Head specialization meets systems optimization: KV compression methods exploit that only a subset of heads behave like global "retrieval" heads
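To put the cache-versus-weights point in numbers, a back-of-the-envelope sketch, again assuming an fp16 Llama-3-8B-like configuration (32 layers, 8 KV heads, head dim 128); the figures are illustrative estimates, not measurements:

```python
def gib(n_bytes):
    return n_bytes / 2 ** 30

weights_gib = gib(8e9 * 2)           # ~14.9 GiB of fp16 weights for an 8B-parameter model
per_token_kv = 2 * 32 * 8 * 128 * 2  # (K+V) * layers * H_kv * d_head * fp16 bytes = 128 KiB per token
for batch, ctx in [(1, 8_192), (1, 128_000), (8, 128_000)]:
    cache = gib(batch * ctx * per_token_kv)
    print(f"batch={batch:>2}, ctx={ctx:>7,}: KV cache {cache:6.1f} GiB vs weights {weights_gib:.1f} GiB")
# At 128k context the cache alone rivals the weights; with modest batching it dominates.
```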