Mixed Precision Training
Backpropagation, Training & Optimization DS practice problem on Onlearn.
Difficulty: medium.
Topics: Understanding Mixed Precision Training, Loss Scaling, Half-Precision Format, Gradient Overflow Detection, Accumulation Precision, Type Casting, Numerical Analysis, Deep Learning Optimization, Computer Architecture, Computational Linear Algebra, Software Engineering for ML, Floating Point Arithmetic, Gradient-Based Optimization, Memory Management, Automatic Differentiation, Hardware Acceleration.
Write a Python class that implements mixed precision training, using both float32 and float16 data types to reduce memory usage and speed up computation. Implement `__init__(self, loss_scale=1024.0)` to initialize the class with a loss scaling factor; `forward(self, weights, inputs, targets)` to perform the forward pass in float16 and return the scaled Mean Squared Error (MSE) loss as a float32 value; and `backward(self, gradients)` to unscale the gradients and check them for overflow. Use float16 for computation but float32 for gradient accumulation. Return the gradients as float32, and set them to zero if overflow is detected. Use only NumPy.
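The task above can be sketched as follows. This is a minimal sketch under stated assumptions, not a reference solution. The class name `MixedPrecision` is hypothetical. Because the task does not name a model, the forward pass assumes a linear one (`predictions = inputs @ weights`). "Overflow" is read as any inf or NaN in the unscaled gradients.

```python
import numpy as np


class MixedPrecision:
    """Sketch of mixed precision: float16 forward pass, float32 loss and gradients.

    Assumption (not given by the task): the model is linear,
    predictions = inputs @ weights.
    """

    def __init__(self, loss_scale=1024.0):
        # Factor applied to the loss in forward() and removed in backward().
        self.loss_scale = float(loss_scale)

    def forward(self, weights, inputs, targets):
        # Cast all operands down to float16 for the forward computation.
        w16 = np.asarray(weights, dtype=np.float16)
        x16 = np.asarray(inputs, dtype=np.float16)
        y16 = np.asarray(targets, dtype=np.float16)
        preds16 = x16 @ w16          # assumed linear model, computed in float16
        residual16 = preds16 - y16
        # Square and average in float32 so the reduction does not overflow
        # or lose precision in float16.
        residual32 = residual16.astype(np.float32)
        mse = np.mean(residual32 ** 2, dtype=np.float32)
        # Apply the loss scale in float32 and return a float32 scalar.
        return np.float32(mse * np.float32(self.loss_scale))

    def backward(self, gradients):
        # Promote to float32 before unscaling (float32 gradient accumulation).
        grads32 = np.asarray(gradients, dtype=np.float32)
        unscaled = grads32 / np.float32(self.loss_scale)
        # Any inf or NaN means the scaled gradients overflowed; the task asks
        # for all-zero gradients in that case.
        if not np.all(np.isfinite(unscaled)):
            return np.zeros_like(unscaled, dtype=np.float32)
        return unscaled


# Usage example with small, exactly representable values.
mp = MixedPrecision(loss_scale=1024.0)
loss = mp.forward(np.array([0.5, -0.25]),
                  np.array([[1.0, 2.0], [3.0, 4.0]]),
                  np.array([0.0, 1.0]))
print(float(loss))  # → 128.0  (MSE 0.125 times 1024)
grads = mp.backward(np.array([512.0, -256.0], dtype=np.float16))
overflowed = mp.backward(np.array([np.inf, 1.0], dtype=np.float16))
```

The multiplication by `loss_scale` happens in float32 on purpose. The largest finite float16 value is 65504, so scaling a float16 loss by 1024 would already overflow once the loss exceeds about 64. A full training loop would usually also adjust the loss scale dynamically, for example lowering it after an overflow. The task does not ask for that, so the sketch leaves it out.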