Variance Reduction in TD Learning via Truncated Importance Sampling

Advanced RL Theory, Planning & TD Learning DS practice problem on Onlearn.

Difficulty: medium.

Topics: Variance Reduction in TD Learning via Truncated Importance Sampling, Truncated Importance Weights, Bias-Variance Tradeoff, Retrace Algorithm, V-trace Operator, Likelihood Ratio Gradient, Reinforcement Learning, Probability Theory, Statistical Estimation, Stochastic Optimization, Information Theory, Temporal Difference Learning, Importance Sampling, Variance Reduction Techniques, Off-Policy Evaluation, Monte Carlo Methods.

Implement a function that performs off-policy TD(0) prediction with two strategies and measures the variance of their update signals:

1. Standard IS TD(0): uses the full importance sampling ratio to correct for the mismatch between the behavior and target policies.
2. Truncated IS TD(0): clips the importance sampling ratio at a maximum value c_bar. This reduces variance at the cost of introducing some bias.

For each method, maintain a separate set of value estimates initialized to zero. Process all episodes sequentially, updating the values step by step. At each step, record the weighted TD error used for the update. After all episodes have been processed, compute the population variance (numpy's default np.var, i.e. ddof=0) of the collected weighted TD errors for each state under each method.

Parameters:
- episodes: list of episodes. Each episode is a list of (state, action, reward) tuples.
- behavior_policy: 2D list of shape (num_states, num_actions) giving b(a|s).
- target_policy: 2D list of shape (num_states, num_actions) giving pi(a|s).
- num_states: number of states (0-indexed).
- num_actions: number of actions.
- alpha: learning rate.
- gamma: discount factor.
- c_bar: maximum allowed importance sampling ratio for the truncated method.

For terminal transitions (the last step of an episode), the value of the next state is 0.

Returns: a dictionary with four keys, each mapping to a numpy array of shape (num_states,) rounded to 4 decimal places:
- "V_standard": value estimates from standard IS TD(0).
- "V_truncated": value estimates from truncated IS TD(0).
- "var_standard": per-state variance of the weighted TD errors from the standard method.
- "var_truncated": per-state variance of the weighted TD errors from the truncated method.

For states that are never visited, both the value and the variance remain 0.0.
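The specification above can be sketched in Python roughly as follows. This is a minimal sketch under stated assumptions, not a reference solution: the function name `off_policy_td0_variance` and the helper `_td_update` are hypothetical (the problem does not name them). It assumes that the next state of a non-terminal step is the state of the following tuple in the same episode, that b(a|s) > 0 for every logged action, and that the weighted TD error is rho * delta with delta = r + gamma * V(s') - V(s) and rho = pi(a|s) / b(a|s), applied as V(s) += alpha * rho * delta. The truncated variant replaces rho with min(rho, c_bar).

```python
import numpy as np


def _td_update(V, errors, s, r, s_next, rho, alpha, gamma):
    """Hypothetical helper: one importance-weighted TD(0) step.

    Assumes the weighted TD error is rho * delta and the update is
    V[s] += alpha * rho * delta; s_next is None for terminal steps.
    """
    v_next = 0.0 if s_next is None else V[s_next]  # terminal => next value 0
    weighted = rho * (r + gamma * v_next - V[s])
    V[s] += alpha * weighted
    errors[s].append(weighted)  # record the signal actually used


def off_policy_td0_variance(episodes, behavior_policy, target_policy,
                            num_states, num_actions, alpha, gamma, c_bar):
    b = np.asarray(behavior_policy, dtype=float)
    pi = np.asarray(target_policy, dtype=float)
    assert b.shape == pi.shape == (num_states, num_actions)

    # Separate value tables and error logs per method, all starting at zero.
    V_std, V_trunc = np.zeros(num_states), np.zeros(num_states)
    errs_std = [[] for _ in range(num_states)]
    errs_trunc = [[] for _ in range(num_states)]

    for episode in episodes:
        for t, (s, a, r) in enumerate(episode):
            # Assumption: the next state is the state of the next tuple.
            s_next = episode[t + 1][0] if t + 1 < len(episode) else None
            rho = pi[s, a] / b[s, a]    # full IS ratio (assumes b(a|s) > 0)
            rho_bar = min(rho, c_bar)   # truncated IS ratio
            _td_update(V_std, errs_std, s, r, s_next, rho, alpha, gamma)
            _td_update(V_trunc, errs_trunc, s, r, s_next, rho_bar, alpha, gamma)

    def per_state_var(errs):
        # Population variance (np.var default ddof=0); unvisited states -> 0.0.
        return np.array([np.var(e) if e else 0.0 for e in errs])

    return {
        "V_standard": np.round(V_std, 4),
        "V_truncated": np.round(V_trunc, 4),
        "var_standard": np.round(per_state_var(errs_std), 4),
        "var_truncated": np.round(per_state_var(errs_trunc), 4),
    }


if __name__ == "__main__":
    # Toy data: uniform behavior policy, deterministic target in states 0 and 1.
    behavior = [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]
    target = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
    episodes = [[(0, 0, 1.0), (1, 1, 0.0)], [(0, 0, 1.0), (1, 1, 0.0)]]
    result = off_policy_td0_variance(episodes, behavior, target,
                                     num_states=3, num_actions=2,
                                     alpha=0.5, gamma=1.0, c_bar=1.0)
    for key, value in result.items():
        print(key, value)
```

In this toy run the full ratio in state 0 is 2, so the standard method records weighted errors [2.0, 0.0] there (variance 1.0), while the truncated method with c_bar = 1 records [1.0, 0.5] (variance 0.0625). State 2 is never visited and stays at 0.0. Sharing one `_td_update` helper keeps the two methods identical except for the ratio they pass in, which is the only difference the problem describes.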