Gradient Temporal Difference (GTD2) for Off-Policy Value Prediction

Representation Learning, Advanced Theory & Miscellaneous DS practice problem on Onlearn.

Difficulty: medium.

Topics: Gradient Temporal Difference (GTD2) for Off-Policy Value Prediction, Importance Sampling, Bellman Error, Fixed-Point Iteration, Semi-Gradient Descent, Feature Vector Representation, Reinforcement Learning, Stochastic Optimization, Linear Algebra, Statistical Learning Theory, Dynamical Systems, Temporal Difference Learning, Off-Policy Evaluation, Gradient-Based Methods, Function Approximation, Projection Operators.

Implement the GTD2 algorithm (a gradient temporal difference method, closely related to TDC, the TD method with gradient correction) for policy evaluation with linear function approximation. GTD2 is a two-timescale stochastic gradient descent method that minimizes the Mean Squared Projected Bellman Error (MSPBE). It maintains two weight vectors:

- Primary weights w: used for the value function approximation V(s) = w^T phi(s).
- Secondary weights v: an auxiliary vector that corrects the update direction to follow the true gradient of the MSPBE objective.

The algorithm processes transitions one at a time. For each transition (s, r, s_next, done):

1. Compute the feature vectors x = features[s] and x_next = features[s_next]. If the transition is terminal (done=True), set x_next to the zero vector.
2. Compute the TD error using the current primary weights.
3. Update the primary weights w using the secondary weights v to form a corrected gradient direction.
4. Update the secondary weights v to track the TD error projected onto the feature space.

The key distinction of GTD2 from standard TD is that the primary weight update depends on v rather than directly on the TD error. This is what allows its convergence guarantees to hold even under off-policy sampling with function approximation.

Write a function gtd2_prediction that processes all transitions sequentially and returns the final weight vectors.

Inputs:

- features: numpy array of shape (num_states, d), where features[s] is the feature vector for state s
- transitions: list of tuples (s, r, s_next, done) representing observed transitions
- gamma: discount factor (float)
- alpha_w: learning rate for primary weights (float)
- alpha_v: learning rate for secondary weights (float)
- w_init: initial primary weights, numpy array of shape (d,)
- v_init: initial secondary weights, numpy array of shape (d,)

Returns: a tuple (w, v) where each is a list of floats rounded to 4 decimal places.
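The four steps above can be sketched in Python as follows. This is a minimal sketch rather than a reference solution, and it rests on a few assumptions the problem statement does not spell out. It uses the standard GTD2 update rules, delta = r + gamma * w^T x_next - w^T x, then w <- w + alpha_w * (x - gamma * x_next) * (x^T v) and v <- v + alpha_v * (delta - x^T v) * x. Both updates read the values of w and v from before the current step, so x^T v is computed once per transition. No importance sampling ratio is applied, because the transitions carry none.

```python
import numpy as np


def gtd2_prediction(features, transitions, gamma, alpha_w, alpha_v, w_init, v_init):
    """Run GTD2 over the transitions in order and return (w, v) as rounded lists."""
    features = np.asarray(features, dtype=float)
    # Copy the initial weights so the caller's arrays are not modified.
    w = np.array(w_init, dtype=float)
    v = np.array(v_init, dtype=float)

    for s, r, s_next, done in transitions:
        x = features[s]
        # Terminal transitions bootstrap from a zero feature vector.
        x_next = np.zeros_like(x) if done else features[s_next]

        # Step 2: TD error under the current primary weights.
        delta = r + gamma * (w @ x_next) - (w @ x)

        # Assumption: x^T v is taken before either weight vector changes,
        # so both updates see the same pre-step values.
        xv = x @ v

        # Step 3: primary update, driven by v rather than by delta directly.
        w = w + alpha_w * (x - gamma * x_next) * xv

        # Step 4: secondary update, moving x^T v toward the TD error.
        v = v + alpha_v * (delta - xv) * x

    return np.round(w, 4).tolist(), np.round(v, 4).tolist()
```

As a quick check under these assumptions, take one-hot features np.eye(2), transitions [(0, 1.0, 1, False), (0, 1.0, 1, False), (1, 0.0, 0, True)], gamma = 0.9, alpha_w = 0.1, alpha_v = 0.5 and zero initial weights. Working through the updates by hand gives w = [0.05, -0.045] and v = [0.75, 0.0225]. The first transition leaves w unchanged because v starts at zero, which shows how the primary update is gated by the secondary weights.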