Branching Factor Impact on Sample Backups
Representation Learning, Advanced Theory & Miscellaneous DS practice problem on Onlearn.
Difficulty: medium.
Topics: Branching Factor Impact on Sample Backups, Branching Factor, Backup Operator, Lookahead Depth, Discount Factor, Temporal Difference Error, Reinforcement Learning, Computational Complexity, Information Theory, Probabilistic Graphical Models, Dynamic Programming, Monte Carlo Tree Search, Bellman Equations, State Space Search, Value Function Approximation, Markov Decision Processes.
Implement a function that analyzes how the branching factor (the number of possible successor states) of each state-action pair in a Markov Decision Process affects the accuracy of sample-based value backups compared to full expected backups.

In reinforcement learning, a full (expected) backup computes the exact action value by summing over all possible next states, weighted by their transition probabilities:

    Q(s, a) = sum over all s' of p(s'|s,a) [r(s,a,s') + gamma V(s')]

A sample backup draws a single successor state s' according to the transition distribution and computes r(s,a,s') + gamma V(s') from that one sample. Sample backups are computationally cheaper (O(1) versus O(b) per backup, where b is the branching factor), but they introduce variance.

Your function should:

1. For each state-action pair in the MDP, determine its branching factor (number of possible next states)
2. Compute the exact Q-value using the full expected backup
3. Perform n_samples independent sample backups, drawing successor states according to the transition probabilities
4. Compute the root mean squared error (RMSE) of the sample backup values relative to the exact Q-value
5. Group results by branching factor and return the average RMSE for each branching factor

Process state-action pairs in sorted order (by state index, then action index). Use numpy.random.RandomState for reproducibility.

Args:
    mdp: dict mapping (state, action) tuples to lists of (next_state, probability, reward) tuples
    values: list of float, current value estimates V(s) for each state
    gamma: float, discount factor
    n_samples: int, number of sample backups per state-action pair
    seed: int, random seed

Returns:
    dict mapping branching factor (int) to average RMSE (float), rounded to 4 decimal places, with keys in ascending order
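The steps above can be sketched in Python as follows. This is a minimal sketch, not a reference solution. The function name `branching_factor_rmse` is hypothetical, the branching factor is taken to be the length of each transition list, and all n_samples successors for a pair are drawn with a single `RandomState.choice` call. A grader that consumes random numbers in a different order may produce different digits for the same seed.

```python
import math
from collections import defaultdict

import numpy as np


def branching_factor_rmse(mdp, values, gamma, n_samples, seed):
    """Average RMSE of sample backups, grouped by branching factor.

    Hypothetical name; assumes each transition list has probabilities
    that sum to 1 and that n_samples >= 1.
    """
    rng = np.random.RandomState(seed)
    rmse_by_b = defaultdict(list)

    # Sorting (state, action) tuples orders by state index, then action index.
    for key in sorted(mdp):
        transitions = mdp[key]
        b = len(transitions)  # assumption: one list entry per successor state

        probs = np.array([p for _, p, _ in transitions], dtype=float)
        # Backup target r + gamma * V(s') for every possible successor.
        targets = np.array(
            [r + gamma * values[s_next] for s_next, _, r in transitions],
            dtype=float,
        )

        # Full expected backup: probability-weighted sum of the targets.
        q_exact = float(np.dot(probs, targets))

        # Sample backups: draw successor indices from the transition distribution.
        # Assumption: one vectorized draw per pair, not one call per sample.
        idx = rng.choice(b, size=n_samples, p=probs)
        errors = targets[idx] - q_exact
        rmse = math.sqrt(float(np.mean(errors ** 2)))

        rmse_by_b[b].append(rmse)

    # Ascending keys; average the per-pair RMSEs within each branching factor.
    return {b: round(float(np.mean(rmse_by_b[b])), 4) for b in sorted(rmse_by_b)}


if __name__ == "__main__":
    # Toy MDP: one pair with two equally likely successors, one deterministic pair.
    toy_mdp = {
        (0, 0): [(0, 0.5, 0.0), (1, 0.5, 1.0)],
        (1, 0): [(1, 1.0, 3.0)],
    }
    print(branching_factor_rmse(toy_mdp, [0.0, 0.0], 0.9, 100, 42))
```

In the toy MDP, the pair with one successor has zero error because the sample backup equals the full backup. For the pair with two equally likely successors, the targets are 0 and 1 and the exact Q-value is 0.5, so every sample is off by exactly 0.5 and the RMSE is 0.5 regardless of the seed.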