Kernel Fusion Memory Savings Calculator
Infrastructure, Parallelism & Hardware Efficiency DS practice problem on Onlearn.
Difficulty: medium.
Topics: Kernel Fusion Memory Savings Calculator, Global Memory Access Reduction, Register Pressure Analysis, Shared Memory Buffering, Kernel Boundary Elimination, Write-Back Minimization, Compiler Optimization, Memory Hierarchy Management, Parallel Computing, Computational Graph Analysis, Hardware Architecture, Operator Fusion, Tiling and Blocking, Kernel Launch Overhead, Data Locality Optimization, Intermediate Representation (IR) Transformation.
Implement a function kernel_fusion_savings that calculates the memory traffic saved by fusing a chain of GPU kernel operations, compared to executing them separately.

In GPU computing, each kernel launch reads its inputs from and writes its outputs to global memory (DRAM). When multiple operations run as separate kernels, every intermediate result must be written out by one kernel and read back by the next, creating unnecessary memory traffic. Kernel fusion combines multiple operations into a single kernel, keeping intermediate results in fast on-chip registers or shared memory.

Function Inputs
input_elements (int): Number of elements in the initial input tensor to the first operation.
operations (list of dicts): A list representing the chain of operations, where each dict has:
  'output_elements' (int): Number of elements in the output tensor of this operation.
  'extra_param_elements' (int, optional): Number of additional parameter elements this operation reads (e.g., bias vectors, scale factors). Defaults to 0 if not provided.
dtype_bytes (int): Number of bytes per element (default: 4, i.e., float32).
memory_bandwidth_gbps (float or None): GPU memory bandwidth in GB/s. If provided, include timing and speedup estimates in the output.
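To make the input format concrete, here is a minimal sketch of a hypothetical two-operation chain (a bias add followed by an activation; the operation names and sizes are illustrative, not part of the problem). It assumes each separately launched kernel reads its chained input plus its extra parameters and writes its output, while a fused kernel reads the initial input and all parameters once and writes only the final output. Identifiers are written with underscores; the helper per_kernel_bytes is made up for this illustration.

```python
# Hypothetical chain: both operations are elementwise over 1024 elements.
input_elements = 1024
operations = [
    {"output_elements": 1024, "extra_param_elements": 1024},  # e.g. bias add
    {"output_elements": 1024},                                # e.g. activation, no extra params
]
dtype_bytes = 4  # float32


def per_kernel_bytes(input_elements, operations, dtype_bytes):
    """Bytes each separately launched kernel moves through global memory."""
    traffic = []
    chained_in = input_elements
    for op in operations:
        extra = op.get("extra_param_elements", 0)  # optional key, defaults to 0
        out = op["output_elements"]
        # Read chained input and parameters, then write the output.
        traffic.append((chained_in + extra + out) * dtype_bytes)
        chained_in = out  # this output feeds the next operation
    return traffic


unfused = per_kernel_bytes(input_elements, operations, dtype_bytes)

# Fused: initial input and all parameters are read once; only the final
# output is written. Intermediates never reach global memory.
total_extra = sum(op.get("extra_param_elements", 0) for op in operations)
fused_bytes = (input_elements + total_extra + operations[-1]["output_elements"]) * dtype_bytes

print(unfused, sum(unfused), fused_bytes)
```

In this example the intermediate tensor accounts for the whole difference: it is written once and read once in the unfused case, and never touches DRAM in the fused case.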
Important Notes
Operations form a sequential chain: the output of operation i becomes the main input to operation i+1.
Each operation may also read additional parameters beyond its chained input.
Memory traffic is measured in bytes (elements multiplied by dtype_bytes).
When bandwidth is provided, compute time as total bytes divided by bandwidth, converted to milliseconds.

Function Output
Return a dictionary with:
  'unfused_memory_bytes': Total global memory traffic (reads + writes) when all operations run as separate kernels.
  'fused_memory_bytes': Total global memory traffic when all operations are fused into one kernel.
  'memory_saved_bytes': Difference between unfused and fused traffic.
  'savings_percent': Percentage of memory traffic saved (rounded to 2 decimal places).
If memory_bandwidth_gbps is provided, also include:
  'unfused_time_ms': Estimated unfused execution time in milliseconds (rounded to 4 decimal places).
  'fused_time_ms': Estimated fused execution time in milliseconds (rounded to 4 decimal places).
  'speedup': Ratio of unfused to fused memory traffic (rounded to 4 decimal places).
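One possible implementation of the specification above is sketched below. It is not a reference solution and rests on several assumptions the problem does not state: identifiers use underscores, 1 GB/s is taken as 10^9 bytes per second, the bandwidth is positive when given, an empty operation list yields zero traffic, and savings_percent and speedup fall back to 0.0 and 1.0 when the denominator would be zero.

```python
def kernel_fusion_savings(input_elements, operations, dtype_bytes=4,
                          memory_bandwidth_gbps=None):
    """Estimate global-memory traffic for an unfused vs. fused operation chain."""
    unfused_elements = 0
    total_extra = 0
    chained_in = input_elements
    for op in operations:
        extra = op.get("extra_param_elements", 0)
        out = op["output_elements"]
        # Separate kernel: read chained input + params, write output.
        unfused_elements += chained_in + extra + out
        total_extra += extra
        chained_in = out  # becomes the main input of the next operation

    # Fused kernel: read initial input + all params once, write final output.
    # Assumption: an empty chain moves no data at all.
    fused_elements = (input_elements + total_extra + chained_in) if operations else 0

    unfused_bytes = unfused_elements * dtype_bytes
    fused_bytes = fused_elements * dtype_bytes
    saved_bytes = unfused_bytes - fused_bytes
    # Assumption: report 0.0 percent when there is no unfused traffic.
    percent = round(100.0 * saved_bytes / unfused_bytes, 2) if unfused_bytes else 0.0

    result = {
        "unfused_memory_bytes": unfused_bytes,
        "fused_memory_bytes": fused_bytes,
        "memory_saved_bytes": saved_bytes,
        "savings_percent": percent,
    }

    if memory_bandwidth_gbps is not None:
        # Assumption: 1 GB/s = 1e9 bytes/s, i.e. 1e6 bytes per millisecond per GB/s.
        bytes_per_ms = memory_bandwidth_gbps * 1e9 / 1e3
        result["unfused_time_ms"] = round(unfused_bytes / bytes_per_ms, 4)
        result["fused_time_ms"] = round(fused_bytes / bytes_per_ms, 4)
        # Assumption: speedup is 1.0 when the fused traffic is zero.
        result["speedup"] = round(unfused_bytes / fused_bytes, 4) if fused_bytes else 1.0

    return result


# Usage: three elementwise ops over one million float32 elements at 1000 GB/s.
ops = [{"output_elements": 1_000_000} for _ in range(3)]
print(kernel_fusion_savings(1_000_000, ops, dtype_bytes=4, memory_bandwidth_gbps=1000.0))
```

Because both time estimates divide by the same bandwidth, the speedup here depends only on the byte counts. The model is purely bandwidth-bound: it ignores kernel launch overhead, caching, and compute time.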