Attention Sink Detection

Attention Mechanisms DS practice problem on Onlearn.

Difficulty: medium.

Topics: Attention Sink Detection, Attention Sink Phenomenon, Softmax Temperature Scaling, KV Cache Eviction, Initial Token Embedding, Attention Score Sparsity, Natural Language Processing, Deep Learning Architectures, Information Theory, Computational Complexity, Signal Processing, Transformer Attention Mechanisms, Sequence Modeling, Activation Dynamics, Model Interpretability, Inference Optimization.

In transformer models, certain token positions consistently absorb disproportionately high attention weights across query positions and attention heads, regardless of their semantic content. These positions are called attention sinks and are typically observed at the beginning-of-sequence (BOS) token or other early positions. This phenomenon was identified in the StreamingLLM research and has important implications for efficient long-context inference. When the initial tokens are removed from a KV cache, model performance degrades catastrophically, not because those tokens carry important semantic information, but because they serve as stable "dump" positions for excess attention probability mass.

Implement a function detect_attention_sinks that takes a 3D attention weight array of shape (num_heads, seq_len, seq_len) and a threshold value, then identifies which token positions are attention sinks.

The entry attn_weights[h, i, j] represents how much query position i attends to key position j in head h. Each row (across the key dimension) sums to 1.0, since it is a softmax distribution. A token position j is considered an attention sink if the average attention it receives, taken over all heads h and all query positions i, meets or exceeds the given threshold.

The function should return a dictionary with three keys:

'sink_positions': a sorted list of integer positions identified as sinks
'avg_attention_received': a list of floats (rounded to 4 decimal places) giving the average attention received by each position
'sink_scores': a list of floats (rounded to 4 decimal places) containing the average attention values for the sink positions only, in the same order as 'sink_positions'
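One possible way to approach the task is sketched below with NumPy. This is a minimal sketch rather than a reference solution, and it makes a few assumptions the problem statement does not settle: the underscore spellings of the function name and dictionary keys, the use of NumPy for the input array, and comparing the unrounded averages against the threshold before rounding for output.

```python
import numpy as np


def detect_attention_sinks(attn_weights, threshold):
    """Find key positions whose mean received attention is >= threshold.

    attn_weights: array-like of shape (num_heads, seq_len, seq_len),
                  where attn_weights[h, i, j] is the weight query i
                  places on key j in head h.
    threshold:    minimum average attention for a position to be a sink.
    """
    attn = np.asarray(attn_weights, dtype=float)
    if attn.ndim != 3 or attn.shape[1] != attn.shape[2]:
        raise ValueError("expected shape (num_heads, seq_len, seq_len)")

    # Average over heads (axis 0) and query positions (axis 1),
    # leaving one value per key position j.
    avg = attn.mean(axis=(0, 1))

    # Assumption: the threshold test uses the unrounded averages.
    # flatnonzero returns indices in ascending order, so the list is sorted.
    sink_positions = [int(j) for j in np.flatnonzero(avg >= threshold)]

    return {
        "sink_positions": sink_positions,
        "avg_attention_received": [round(float(v), 4) for v in avg],
        "sink_scores": [round(float(avg[j]), 4) for j in sink_positions],
    }


# Small illustrative input: 2 heads, 3 positions, every row sums to 1.0.
example = [
    [[0.8, 0.1, 0.1], [0.7, 0.2, 0.1], [0.6, 0.1, 0.3]],
    [[0.9, 0.05, 0.05], [0.5, 0.3, 0.2], [0.7, 0.1, 0.2]],
]
result = detect_attention_sinks(example, threshold=0.5)
```

For this input the averages per key position are 0.7, 0.1417 and 0.1583, so only position 0 meets the 0.5 threshold. Because every row sums to 1.0, the per-position averages also sum to 1.0, which makes a uniform baseline of 1 / seq_len a natural reference point when choosing a threshold.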