Speculative Decoding Verification
LLM Inference & Memory Systems DS practice problem on Onlearn.
Difficulty: hard.
Topics: Speculative Decoding Verification, Draft Model Approximation, KV Cache Rejection Sampling, Speculative Token Acceptance, Memory Bandwidth Bottlenecking, Logit Distribution Alignment, Natural Language Processing, Distributed Systems, Computational Complexity, Computer Architecture, Probability and Statistics, Autoregressive Decoding, Memory Hierarchy Management, Parallel Computing Paradigms, Probabilistic Model Calibration, Latency Optimization Techniques.
Speculative decoding is an inference-time optimization for large language models. A smaller, faster "draft" model proposes K candidate tokens, and a larger "target" model verifies them in a single forward pass. The verification step decides which draft tokens to accept while preserving the target model's output distribution.

Implement the verification function for speculative decoding. Given K draft tokens along with the full probability distributions from both the draft and target models at each position, determine which tokens to accept.

The function should process each draft token sequentially:
- If a token is accepted, move to the next position.
- If a token is rejected, compute an adjusted distribution at that position, resample a replacement token from it, and stop (discard all remaining draft tokens).
- If all K tokens are accepted, return them all.

Your function receives:
- draft tokens: list of K token indices proposed by the draft model
- draft probs: K x V array of draft model probability distributions
- target probs: K x V array of target model probability distributions
- coin flips: list of K random floats in [0, 1) for accept/reject decisions
- resample coin: a single float in [0, 1) used to resample from the adjusted distribution upon rejection

Return: a list of final token indices (accepted draft tokens, plus one resampled token if rejection occurred).
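The statement above does not spell out the acceptance test or the form of the adjusted distribution, so the following is only a minimal sketch under stated assumptions. It assumes the standard speculative sampling rule: with p the target probability and q the draft probability of the proposed token, accept when the coin flip is below min(1, p / q); on rejection, resample from the distribution proportional to max(0, target - draft). It also assumes the resample coin is mapped to a token by inverse-CDF lookup in index order, which the problem's reference solution may or may not do. The function name `verify_draft_tokens` and the underscore parameter names are hypothetical.

```python
from typing import List, Sequence


def verify_draft_tokens(
    draft_tokens: Sequence[int],
    draft_probs: Sequence[Sequence[float]],
    target_probs: Sequence[Sequence[float]],
    coin_flips: Sequence[float],
    resample_coin: float,
) -> List[int]:
    """Sketch of speculative decoding verification (assumed standard rule)."""
    result: List[int] = []
    for i, token in enumerate(draft_tokens):
        q = draft_probs[i][token]   # draft probability of the proposed token
        p = target_probs[i][token]  # target probability of the same token

        # Assumed acceptance rule: accept with probability min(1, p / q).
        # The q <= 0 branch is an edge-case guard, not part of the statement.
        if q <= 0.0:
            accept_prob = 1.0 if p > 0.0 else 0.0
        else:
            accept_prob = min(1.0, p / q)

        if coin_flips[i] < accept_prob:
            result.append(token)
            continue

        # Rejected: build the adjusted (residual) weights max(0, target - draft).
        residual = [
            max(0.0, pt - qd)
            for pt, qd in zip(target_probs[i], draft_probs[i])
        ]
        total = sum(residual)
        if total <= 0.0:
            # Guard for malformed input: fall back to the target distribution.
            residual = list(target_probs[i])
            total = sum(residual)

        # Assumed resampling convention: inverse-CDF lookup in index order,
        # comparing against unnormalized cumulative weights.
        threshold = resample_coin * total
        cumulative = 0.0
        # Default to the last positive-weight index to absorb rounding error.
        replacement = max(idx for idx, w in enumerate(residual) if w > 0.0)
        for idx, weight in enumerate(residual):
            cumulative += weight
            if threshold < cumulative:
                replacement = idx
                break

        result.append(replacement)
        return result  # discard all remaining draft tokens

    return result  # all K draft tokens were accepted


if __name__ == "__main__":
    tokens = verify_draft_tokens(
        draft_tokens=[0, 1],
        draft_probs=[[0.5, 0.25, 0.25], [0.25, 0.5, 0.25]],
        target_probs=[[0.5, 0.25, 0.25], [0.25, 0.25, 0.5]],
        coin_flips=[0.9, 0.75],
        resample_coin=0.5,
    )
    print(tokens)
```

In the usage example, the first token is accepted because the two models agree on it (ratio 1), while the second has ratio 0.25 / 0.5 = 0.5 and the coin flip 0.75 rejects it. The residual weights at that position are nonzero only for token 2, so the call returns the accepted token followed by the resampled one. Passing the coins in as arguments rather than drawing them inside the function keeps the routine deterministic, which is presumably why the problem supplies them.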