Context Parallelism with Ring Attention for Video Models

Infrastructure, Parallelism & Hardware Efficiency DS practice problem on Onlearn.

Difficulty: hard.

Topics: Context Parallelism with Ring Attention for Video Models, Ring-based All-Gather, FlashAttention-2, KV Cache Quantization, Pipeline Bubble Reduction, Block-wise Causal Masking, Distributed Systems, Deep Learning Architectures, High-Performance Computing, Memory Management, Parallel Computing Paradigms, Sequence Parallelism, Attention Mechanisms, Communication Overlap, Tensor Partitioning, Kernel Optimization.

Implement a simulation of Ring Attention used for Context Parallelism in video diffusion and long-context transformer models. In video generation and understanding, sequence lengths can be enormous (thousands of frame tokens). Context Parallelism distributes the sequence across multiple devices. Ring Attention is the communication pattern used: each device holds its own query chunk, and key-value (KV) blocks are passed around in a ring topology so that every device eventually attends to the entire sequence.

Your task is to implement ring_attention_simulate(Q, K, V, num_devices) that:

1. Splits the input sequence (of length seq_len) into num_devices equal chunks along the sequence dimension. Device i owns rows [i*C : (i+1)*C], where C = seq_len // num_devices.
2. Simulates num_devices ring steps. At each step, each device computes partial attention between its local query chunk and the currently available KV chunk, then the KV blocks rotate to the next device in the ring.
3. Accumulates partial attention results across steps using the online softmax trick for numerically stable incremental attention computation across distributed KV blocks.
4. Returns a tuple of:
   - output: A numpy array of shape (seq_len, d) with the final attention output (rounded to 4 decimal places). This should be mathematically equivalent to standard scaled dot-product attention.
   - comm_schedule: A list of num_devices lists, where comm_schedule[step][device] is the index of the source device whose KV block the given device processes at that step.

Assume seq_len is always divisible by num_devices. Use scaled dot-product attention with scaling factor 1 / sqrt(d), where d is the feature dimension.
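
The steps above can be sketched in NumPy as follows. This is one possible reading of the task, not a reference solution. It rests on assumptions the statement does not spell out: at step 0 each device processes its own KV block; "rotate to the next device" means device i passes its block to device (i + 1) mod num_devices, so at step s device i holds the block that originated on device (i - s) mod num_devices; and V has the same number of rows as K. The names m, l and acc for the running maximum, running denominator and unnormalised output are illustrative only.

```python
import numpy as np


def ring_attention_simulate(Q, K, V, num_devices):
    Q = np.asarray(Q, dtype=np.float64)
    K = np.asarray(K, dtype=np.float64)
    V = np.asarray(V, dtype=np.float64)
    seq_len, d = Q.shape
    C = seq_len // num_devices          # rows owned by each device
    scale = 1.0 / np.sqrt(d)

    # Per-device online-softmax state, one entry per local query row.
    m = np.full((num_devices, C), -np.inf)       # running row maximum
    l = np.zeros((num_devices, C))               # running softmax denominator
    acc = np.zeros((num_devices, C, V.shape[1])) # unnormalised output

    comm_schedule = []
    for step in range(num_devices):
        sources = []
        for dev in range(num_devices):
            # Assumed rotation direction: blocks move from device i to i + 1,
            # so after `step` rotations device `dev` holds block (dev - step).
            src = (dev - step) % num_devices
            sources.append(src)

            q = Q[dev * C:(dev + 1) * C]
            k = K[src * C:(src + 1) * C]
            v = V[src * C:(src + 1) * C]

            s = (q @ k.T) * scale                 # (C, C) partial scores
            m_new = np.maximum(m[dev], s.max(axis=1))
            alpha = np.exp(m[dev] - m_new)        # rescales old state; 0 on the first step
            p = np.exp(s - m_new[:, None])        # partial softmax numerators

            l[dev] = alpha * l[dev] + p.sum(axis=1)
            acc[dev] = alpha[:, None] * acc[dev] + p @ v
            m[dev] = m_new
        comm_schedule.append(sources)

    # Normalise once at the end and stitch the device chunks back together.
    output = (acc / l[:, :, None]).reshape(seq_len, V.shape[1])
    return np.round(output, 4), comm_schedule
```

Under the assumed rotation direction, num_devices = 4 gives comm_schedule = [[0, 1, 2, 3], [3, 0, 1, 2], [2, 3, 0, 1], [1, 2, 3, 0]]. If the intended convention is the opposite direction, only the line computing src changes; the accumulated output is the same either way, because every device still sees each KV block exactly once.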