Flash Attention v1 - Forward Pass

LLM Inference & Memory Systems DS practice problem on Onlearn.

Difficulty: hard.

Topics: Understanding Tiling and Recomputation in Flash Attention v1, Softmax Scaling Factor, Row-wise Max Tracking, Accumulated Normalization Constant, Block-level GEMM, Log-Sum-Exp Trick, Deep Learning Optimization, Memory Hierarchy, Attention Mechanisms, High-Performance Computing, Numerical Stability, Tiling and Blocking, Online Softmax, SRAM vs HBM Latency, Kernel Fusion, Matrix Multiplication Optimization.

Implement a simplified version of the Flash Attention v1 forward pass. Given query (Q), key (K), and value (V) matrices, compute the attention output using block-wise tiling. Assume a block size of 2 for simplicity. Your implementation must perform the online softmax update of the running row-wise maximum and the accumulated normalization constant (m_i, l_i) for each block to demonstrate the memory-efficient approach.
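
One possible approach is sketched below in NumPy. It is a minimal illustration under stated assumptions, not a reference solution: the function name flash_attention_forward is hypothetical, scores are assumed to be scaled by 1/sqrt(d), no masking or dropout is applied, and plain array slices stand in for the SRAM tiles a real kernel would load from HBM. For each K/V block, the running maximum m_i and normalization constant l_i of every query row are updated, and the partial output is rescaled so that the full score matrix is never materialized.

```python
import numpy as np

def flash_attention_forward(Q, K, V, block_size=2):
    """Tiled attention forward pass with online softmax (illustrative sketch)."""
    Q = np.asarray(Q, dtype=np.float64)
    K = np.asarray(K, dtype=np.float64)
    V = np.asarray(V, dtype=np.float64)
    n, d = Q.shape
    scale = 1.0 / np.sqrt(d)          # assumed softmax scaling factor

    O = np.zeros((n, V.shape[1]))     # output, kept normalized after each block
    m = np.full(n, -np.inf)           # running row-wise max m_i
    l = np.zeros(n)                   # running normalization constant l_i

    for j in range(0, K.shape[0], block_size):      # outer loop: K/V blocks
        Kj, Vj = K[j:j + block_size], V[j:j + block_size]
        for i in range(0, n, block_size):           # inner loop: Q blocks
            rows = slice(i, i + block_size)
            S = (Q[rows] @ Kj.T) * scale            # block-level GEMM
            m_blk = S.max(axis=1)                   # max within this block
            P = np.exp(S - m_blk[:, None])          # stabilized exponentials
            l_blk = P.sum(axis=1)

            m_new = np.maximum(m[rows], m_blk)
            alpha = np.exp(m[rows] - m_new)         # rescales the old statistics
            beta = np.exp(m_blk - m_new)            # rescales the new block
            l_new = alpha * l[rows] + beta * l_blk

            # Undo the old normalization, add the new block, renormalize.
            O[rows] = ((alpha * l[rows])[:, None] * O[rows]
                       + beta[:, None] * (P @ Vj)) / l_new[:, None]
            m[rows], l[rows] = m_new, l_new
    return O
```

The result can be checked against a direct softmax(Q K^T / sqrt(d)) V computation with np.allclose. The outer loop over K/V blocks and the inner loop over Q blocks follow the loop order usually described for v1; with a block size of 2, each tile of scores is at most 2 x 2.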