19 · Efficiency

Efficient Attention at Scale: KV Cache, GQA & FlashAttention

Canonical Papers

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

Dao, Fu, Ermon, Rudra, Ré · NeurIPS 2022

GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints

Ainslie et al. · EMNLP 2023

SnapKV: LLM Knows What You Are Looking for Before Generation

Li et al. · NeurIPS 2024

Core Mathematics

Modern LLMs are defined not just by *what* attention computes, but by *how* that computation is made feasible at long context with low latency.

Autoregressive attention with KV cache:

During decoding, you form the new query $q_t$ and attend against cached past keys and values:

$$o_t = \mathrm{softmax}\!\left(\frac{q_t K_{1:t}^{\top}}{\sqrt{d_k}}\right) V_{1:t}$$

where $(K_{1:t}, V_{1:t})$ are stored, not recomputed. The KV cache is essential: without it, every decoding step would re-project keys and values for the entire prefix, and generation would slow dramatically as the sequence grows.
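
A minimal single-head sketch of one cached decoding step in PyTorch (the function name and shapes are illustrative, not any library's API):

```python
import torch

def decode_step(q_t, k_t, v_t, k_cache, v_cache):
    """One autoregressive decoding step with a KV cache.

    q_t, k_t, v_t: (d_k,) projections for the new token.
    k_cache, v_cache: (t-1, d_k) cached keys/values for the prefix.
    Returns the attention output o_t and the updated caches.
    """
    d_k = q_t.shape[-1]
    # Append the new key/value instead of recomputing the whole prefix.
    K = torch.cat([k_cache, k_t[None, :]], dim=0)   # (t, d_k)
    V = torch.cat([v_cache, v_t[None, :]], dim=0)   # (t, d_k)
    # o_t = softmax(q_t K^T / sqrt(d_k)) V
    scores = (K @ q_t) / d_k ** 0.5                 # (t,)
    weights = torch.softmax(scores, dim=0)
    o_t = weights @ V                               # (d_k,)
    return o_t, K, V

# Toy usage: decode 3 tokens with random projections.
d_k = 8
k_cache = torch.empty(0, d_k)
v_cache = torch.empty(0, d_k)
for _ in range(3):
    q_t, k_t, v_t = torch.randn(3, d_k)
    o_t, k_cache, v_cache = decode_step(q_t, k_t, v_t, k_cache, v_cache)
```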

Grouped-Query Attention (GQA):

Multiple query heads share fewer KV heads. Let query head $h \in \{1,\dots,H_q\}$ map to KV-head group $g(h) \in \{1,\dots,H_{kv}\}$:

$$o_h = \mathrm{softmax}\!\left(\frac{Q_h K_{g(h)}^{\top}}{\sqrt{d_k}}\right) V_{g(h)}$$

This interpolates between Multi-Head Attention ($H_{kv} = H_q$, no sharing) and Multi-Query Attention ($H_{kv} = 1$, maximum sharing).
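
A compact sketch of GQA via KV-head sharing, assuming $H_q$ is divisible by $H_{kv}$ and omitting the causal mask and output projection for brevity:

```python
import torch
import torch.nn.functional as F

def grouped_query_attention(Q, K, V, n_kv_heads):
    """Q: (H_q, T, d_k); K, V: (H_kv, T, d_k) with H_q divisible by H_kv.
    Query head h attends to KV group g(h) = h // (H_q // H_kv)."""
    H_q = Q.shape[0]
    group_size = H_q // n_kv_heads
    # Expand each KV head so it is shared by `group_size` query heads.
    K = K.repeat_interleave(group_size, dim=0)              # (H_q, T, d_k)
    V = V.repeat_interleave(group_size, dim=0)
    scores = Q @ K.transpose(-2, -1) / Q.shape[-1] ** 0.5   # (H_q, T, T)
    return F.softmax(scores, dim=-1) @ V                    # (H_q, T, d_k)

# H_q = 8 query heads sharing H_kv = 2 KV heads
# (MHA would use 8 KV heads, MQA would use 1).
Q = torch.randn(8, 16, 64)
K = torch.randn(2, 16, 64)
V = torch.randn(2, 16, 64)
out = grouped_query_attention(Q, K, V, n_kv_heads=2)
```

Note that only the cached $K$ and $V$ shrink; the attention computation per query head is unchanged.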

KV cache memory scaling:

Per layer, the KV cache grows as:

$$\text{Mem}_{KV} \propto T \cdot H_{kv} \cdot d_{\text{head}} \cdot 2 \cdot \text{bytes}$$

where $T$ is the context length and the factor of 2 accounts for storing both keys and values. GQA shrinks KV memory to a fraction $H_{kv}/H_q$ of what full MHA would need.
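
A back-of-the-envelope calculation, assuming a roughly Llama-3-8B-class configuration (32 layers, 8 KV heads vs. 32 query heads, $d_{\text{head}} = 128$, fp16); the exact numbers are illustrative:

```python
def kv_cache_bytes(T, n_layers, n_kv_heads, d_head, bytes_per_elem=2):
    # 2x for keys and values, per layer, per KV head, per position.
    return T * n_layers * n_kv_heads * d_head * 2 * bytes_per_elem

# Assumed 8B-class config at a 128k-token context, fp16 cache.
gqa = kv_cache_bytes(T=128_000, n_layers=32, n_kv_heads=8, d_head=128)
mha = kv_cache_bytes(T=128_000, n_layers=32, n_kv_heads=32, d_head=128)
print(f"GQA cache: {gqa / 2**30:.1f} GiB")   # ~15.6 GiB
print(f"MHA cache: {mha / 2**30:.1f} GiB")   # ~62.5 GiB (4x larger)
```

At this scale the cache alone rivals or exceeds the fp16 model weights, which is exactly why $H_{kv}$ matters.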

FlashAttention's key insight: Reorder computation to minimize memory movement. Instead of materializing the full (T×T)(T \times T) attention matrix in slow GPU memory (HBM), stream the softmax computation through fast on-chip memory (SRAM) via tiling. Same exact attention math, radically different memory behavior.
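
To make "same math, different memory behavior" concrete, here is a single-query sketch of the online-softmax accumulation that tiling relies on. It is a didactic illustration of the idea, not the fused GPU kernel:

```python
import torch

def tiled_attention_row(q, K, V, block_size=128):
    """Exact attention for one query, streaming over key/value blocks.

    Never materializes the full score row at once; keeps a running max `m`,
    normalizer `l`, and unnormalized output `o`, rescaling them as each
    block arrives (online softmax).
    """
    d_k = q.shape[-1]
    m = torch.tensor(float("-inf"))   # running max of scores
    l = torch.tensor(0.0)             # running softmax denominator
    o = torch.zeros_like(q)           # running (unnormalized) output
    for start in range(0, K.shape[0], block_size):
        Kb = K[start:start + block_size]
        Vb = V[start:start + block_size]
        s = Kb @ q / d_k ** 0.5                  # scores for this block
        m_new = torch.maximum(m, s.max())
        scale = torch.exp(m - m_new)             # rescale old accumulators
        p = torch.exp(s - m_new)
        l = l * scale + p.sum()
        o = o * scale + p @ Vb
        m = m_new
    return o / l

# Matches the naive computation up to floating-point error.
q, K, V = torch.randn(64), torch.randn(1000, 64), torch.randn(1000, 64)
ref = torch.softmax(K @ q / 64 ** 0.5, dim=0) @ V
assert torch.allclose(tiled_attention_row(q, K, V), ref, atol=1e-4)
```

The per-block tensors are what FlashAttention keeps in SRAM; only the small accumulators and the final output ever need to round-trip through HBM.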

Key Equation
$$o_t = \mathrm{softmax}\!\left(\frac{q_t K_{1:t}^{\top}}{\sqrt{d_k}}\right) V_{1:t}$$


Why It Matters for Modern Models

  • Llama 3 explicitly uses GQA for inference efficiency—this is the production default for modern open LLMs
  • Long context makes KV cache the dominant inference cost: memory grows linearly with T, bandwidth becomes the bottleneck
  • FlashAttention enabled much longer contexts and faster training by minimizing memory traffic without approximating attention
  • NeurIPS 2024 / ICLR 2025 focus heavily on KV cache compression (SnapKV, RazorAttention)—this is the active research frontier
  • Inference cost now rivals training cost at scale, making memory-efficient attention essential for deployment

Missing Intuition

What is still poorly explained in textbooks and papers:

  • Attention is "quadratic" on paper, but the real enemy is memory movement—FlashAttention reorders computation to keep data on-chip
  • KV cache is the dominant inference memory budget at long context; it can exceed model weights in footprint
  • GQA is about bandwidth + cache size, not "making attention cheaper"—you still compute attention, but read/store less
  • Head specialization meets systems optimization: KV compression methods exploit that only a subset of heads behave like global "retrieval" heads
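
As a toy illustration of the score-based cache compression idea behind methods like SnapKV (not the paper's exact algorithm), one can let a recent window of queries vote for which cached positions are worth keeping:

```python
import torch

def select_kv_by_observation(K, V, q_window, keep=256):
    """Toy score-based KV compression in the spirit of SnapKV: use the
    queries of a recent 'observation window' to score cached positions,
    then retain only the top-scoring keys/values."""
    d_k = K.shape[-1]
    # (window, T): attention of each recent query over all cached positions
    attn = torch.softmax(q_window @ K.T / d_k ** 0.5, dim=-1)
    votes = attn.sum(dim=0)                    # aggregate over the window
    idx = votes.topk(min(keep, K.shape[0])).indices.sort().values
    return K[idx], V[idx]                      # compressed cache

K, V = torch.randn(4096, 64), torch.randn(4096, 64)
q_window = torch.randn(32, 64)                 # queries of the last 32 tokens
K_small, V_small = select_kv_by_observation(K, V, q_window, keep=512)
```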

