MoE with Shared Expert Forward Pass
MoE, Compression & Scaling DS practice problem on Onlearn.
Difficulty: medium.
Topics: MoE with Shared Expert Forward Pass, Top-k Gating Mechanism, Expert Capacity Constraint, All-to-All Communication, Load Balancing Loss, Shared Parameter Quantization, Deep Learning Architectures, Distributed Systems, Numerical Linear Algebra, Optimization Theory, Computational Complexity, Mixture of Experts, Parallel Computing Paradigms, Sparse Matrix Operations, Gradient-based Learning, Memory Hierarchy Management.
Implement the forward pass of a Mixture of Experts (MoE) layer that includes a shared expert alongside routed experts. This architecture is used in models such as DeepSeek V2/V3, where one expert always processes every token to capture knowledge common to all inputs, while the additional routed experts specialize on different inputs.

Your function takes:
- X: input token representations of shape (num_tokens, d_model)
- W_gate: gating weight matrix of shape (d_model, num_routed_experts), used to compute routing scores
- W_shared: shared expert weight matrix of shape (d_model, d_out)
- W_experts: list of routed expert weight matrices, each of shape (d_model, d_out)
- top_k: number of routed experts to activate per token (default 2)

Your function should return a dictionary with the following keys:
- 'output': final combined output of shape (num_tokens, d_out)
- 'shared_output': output of the shared expert, of shape (num_tokens, d_out)
- 'routed_output': aggregated output of the routed experts, of shape (num_tokens, d_out)
- 'routing_indices': the top_k selected expert indices per token, of shape (num_tokens, top_k), sorted in descending order of gating score
- 'routing_weights': the renormalized gating weights of the selected experts, of shape (num_tokens, top_k)

The gating mechanism applies a softmax over all routed experts to produce scores, selects the top_k experts per token, renormalizes the selected scores so they sum to 1, and computes a weighted combination of the selected experts' outputs. The shared expert output is added directly to the routed output. Each expert (shared and routed) is a plain linear transformation: output = input @ W_expert.
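The steps above can be sketched in Python with NumPy. This is a minimal illustration under stated assumptions, not a reference solution: the function name moe_shared_expert_forward is hypothetical, inputs are assumed to be NumPy arrays or array-likes, experts are bias-free linear maps as the problem specifies, and ties in gating score are assumed to be broken toward the lower expert index.

```python
import numpy as np


def moe_shared_expert_forward(X, W_gate, W_shared, W_experts, top_k=2):
    """Sketch of an MoE forward pass with one always-on shared expert.

    Hypothetical helper name; assumes linear experts (input @ W, no bias).
    """
    X = np.asarray(X, dtype=float)
    W_gate = np.asarray(W_gate, dtype=float)
    W_shared = np.asarray(W_shared, dtype=float)
    W_experts = np.asarray(W_experts, dtype=float)  # (num_experts, d_model, d_out)
    num_tokens = X.shape[0]
    d_out = W_shared.shape[1]

    # Shared expert: every token passes through it.
    shared_output = X @ W_shared

    # Gating scores: softmax over all routed experts (max-subtracted for stability).
    logits = X @ W_gate                                   # (num_tokens, num_experts)
    logits = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)

    # Top-k selection in descending score order; a stable sort on the
    # negated scores keeps the lower expert index first on ties.
    routing_indices = np.argsort(-probs, axis=1, kind="stable")[:, :top_k]
    top_scores = np.take_along_axis(probs, routing_indices, axis=1)

    # Renormalize the selected scores so each token's weights sum to 1.
    routing_weights = top_scores / top_scores.sum(axis=1, keepdims=True)

    # Weighted combination: only the selected experts are evaluated per token.
    routed_output = np.zeros((num_tokens, d_out))
    for t in range(num_tokens):
        for slot in range(top_k):
            e = routing_indices[t, slot]
            routed_output[t] += routing_weights[t, slot] * (X[t] @ W_experts[e])

    return {
        "output": shared_output + routed_output,
        "shared_output": shared_output,
        "routed_output": routed_output,
        "routing_indices": routing_indices,
        "routing_weights": routing_weights,
    }


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    X = rng.standard_normal((4, 8))
    W_gate = rng.standard_normal((8, 5))
    W_shared = rng.standard_normal((8, 3))
    W_experts = [rng.standard_normal((8, 3)) for _ in range(5)]
    out = moe_shared_expert_forward(X, W_gate, W_shared, W_experts, top_k=2)
    print(out["output"].shape)  # (4, 3)
```

The per-token loop is kept for readability and mirrors sparse dispatch, where each token touches only its top_k experts; a batched implementation would more likely group tokens by expert before applying each weight matrix.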