Multi-Head Latent Attention (MLA)
Vision-Language & Cross-Modal Systems DS practice problem on Onlearn.
Difficulty: hard.
Topics: Multi-Head Latent Attention (MLA), Low-Rank Projection, Query-Key-Value Compression, Rotary Positional Embeddings, KV Cache Quantization, Latent Bottleneck Mapping, Linear Algebra, Deep Learning Architectures, Information Theory, Optimization Theory, Multimodal Representation Learning, Matrix Factorization, Attention Mechanisms, Dimensionality Reduction, Parameter Efficiency, Cross-Modal Alignment.
Implement Multi-Head Latent Attention (MLA), an efficient attention mechanism that compresses key-value representations into a low-dimensional latent space before reconstructing them for attention computation. This technique, used in architectures like DeepSeek-V2, dramatically reduces the KV cache memory during inference.

Your function should:
1. Compress the input into a low-rank KV latent representation using a down-projection matrix.
2. Reconstruct keys and values from the compressed latent using separate up-projection matrices.
3. Similarly compress and reconstruct queries using their own down/up-projection matrices.
4. Perform multi-head scaled dot-product attention with the reconstructed Q, K, V.
5. Apply an output projection to the concatenated attention heads.
6. Return both the final output and the compressed KV latent (to illustrate the memory savings).

The function takes:
X: input of shape (seq_len, d_model)
W_dkv: KV down-projection of shape (d_model, d_c_kv)
W_uk: key up-projection of shape (d_c_kv, d_model)
W_uv: value up-projection of shape (d_c_kv, d_model)
W_dq: query down-projection of shape (d_model, d_c_q)
W_uq: query up-projection of shape (d_c_q, d_model)
W_o: output projection of shape (d_model, d_model)
n_heads: number of attention heads (must evenly divide d_model)

Use numerically stable softmax (subtract the maximum before exponentiating).
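The six steps above can be sketched in NumPy as follows. This is one possible reading of the task, not a reference solution. The function names mla and softmax, the underscore spellings of the parameters, and the toy sizes in the usage lines are assumptions made for illustration. The sketch also assumes no causal mask and no rotary embeddings, since the task statement asks for neither, and it scales scores by the square root of the per-head dimension.

```python
import numpy as np


def softmax(scores, axis=-1):
    # Subtract the maximum along the axis before exponentiating so that
    # large scores do not overflow.
    shifted = scores - np.max(scores, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=axis, keepdims=True)


def mla(X, W_dkv, W_uk, W_uv, W_dq, W_uq, W_o, n_heads):
    seq_len, d_model = X.shape
    if d_model % n_heads != 0:
        raise ValueError("n_heads must evenly divide d_model")
    d_head = d_model // n_heads

    # Step 1: compress the input into the KV latent, shape (seq_len, d_c_kv).
    c_kv = X @ W_dkv
    # Step 2: reconstruct keys and values from the shared latent.
    K = c_kv @ W_uk
    V = c_kv @ W_uv
    # Step 3: compress and reconstruct the queries.
    Q = (X @ W_dq) @ W_uq

    def split_heads(M):
        # (seq_len, d_model) -> (n_heads, seq_len, d_head)
        return M.reshape(seq_len, n_heads, d_head).transpose(1, 0, 2)

    Qh, Kh, Vh = split_heads(Q), split_heads(K), split_heads(V)

    # Step 4: scaled dot-product attention per head (assumed unmasked).
    scores = Qh @ Kh.transpose(0, 2, 1) / np.sqrt(d_head)
    weights = softmax(scores, axis=-1)
    heads = weights @ Vh  # (n_heads, seq_len, d_head)

    # Step 5: concatenate the heads and apply the output projection.
    concat = heads.transpose(1, 0, 2).reshape(seq_len, d_model)
    output = concat @ W_o

    # Step 6: return the output together with the compressed KV latent.
    return output, c_kv


# Usage with small, arbitrary sizes (assumed for illustration only).
rng = np.random.default_rng(0)
seq_len, d_model, d_c_kv, d_c_q, n_heads = 4, 8, 2, 3, 2
X = rng.standard_normal((seq_len, d_model))
W_dkv = rng.standard_normal((d_model, d_c_kv))
W_uk = rng.standard_normal((d_c_kv, d_model))
W_uv = rng.standard_normal((d_c_kv, d_model))
W_dq = rng.standard_normal((d_model, d_c_q))
W_uq = rng.standard_normal((d_c_q, d_model))
W_o = rng.standard_normal((d_model, d_model))

out, c_kv = mla(X, W_dkv, W_uk, W_uv, W_dq, W_uq, W_o, n_heads)
print(out.shape, c_kv.shape)  # (4, 8) (4, 2)
```

In this toy setup, caching c_kv stores seq_len * d_c_kv = 8 numbers, whereas caching full keys and values would store 2 * seq_len * d_model = 64, which is the memory saving that step 6 is meant to illustrate.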