On-Policy Trajectory Sampling
Representation Learning, Advanced Theory & Miscellaneous DS practice problem on Onlearn.
Difficulty: medium.
Topics: On-Policy Trajectory Sampling, Generalized Advantage Estimation, Log-Derivative Trick, Trajectory Rollout, Baseline Subtraction, Policy Entropy Regularization, Reinforcement Learning, Stochastic Processes, Information Theory, Control Theory, Statistical Inference, Policy Gradient Methods, Markov Decision Processes, Monte Carlo Estimation, Variance Reduction Techniques, Importance Sampling.
Implement a function on_policy_sample that generates trajectories by following a given stochastic policy in an MDP and computes three on-policy statistics: the empirical state visitation frequency, the state-action visit counts, and the average discounted return. The state visitation frequency is a normalized distribution over states that reflects how often the policy visits each state. This distribution is fundamental in policy optimization: the policy gradient averages over states in proportion to how often they are visited, so frequently visited states contribute most to gradient updates.

Parameters:
- P: NumPy array of shape (n_states, n_actions, n_states) of transition probabilities. P[s, a, s'] is the probability of reaching s' from s under action a.
- R: NumPy array of shape (n_states, n_actions, n_states) of rewards. R[s, a, s'] is the reward for the transition (s, a, s').
- policy: NumPy array of shape (n_states, n_actions) giving a stochastic policy. policy[s, a] is the probability of choosing action a in state s.
- start_state: Integer, the starting state for every episode.
- terminal_states: List of integers indicating terminal states. When the agent enters a terminal state, the episode ends immediately with no further actions.
- gamma: Float, discount factor in [0, 1].
- num_episodes: Integer, number of episodes to sample.
- max_steps: Integer, maximum number of steps per episode.
- seed: Integer, random seed for reproducibility.

Returns: A tuple (state_visitation_freq, state_action_counts, average_return), where:
- state_visitation_freq: A list of floats (length n_states) giving the fraction of total visits that went to each state, rounded to 4 decimal places. If no states were visited, return all zeros.
- state_action_counts: A list of lists of integers (n_states rows, n_actions columns) giving the raw visit count for each (state, action) pair.
- average_return: Float, the mean discounted return across all episodes, rounded to 4 decimal places.

Set the random seed once at the start with np.random.seed(seed), and use np.random.choice for all sampling. Count a visit only for non-terminal states where the agent actually takes an action.
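A minimal sketch of one possible solution follows. It rests on a few assumptions the statement does not pin down: within each step the action is sampled before the next state, the reward received at step t is discounted by gamma**t, an episode that starts in a terminal state contributes a return of 0 and no visits, and num_episodes = 0 yields an average return of 0.0. Treat it as an illustration of the sampling loop, not a reference answer.

```python
import numpy as np

def on_policy_sample(P, R, policy, start_state, terminal_states, gamma,
                     num_episodes, max_steps, seed):
    np.random.seed(seed)  # seed exactly once, before any sampling
    n_states, n_actions, _ = P.shape
    terminal = set(terminal_states)
    state_counts = np.zeros(n_states, dtype=int)
    sa_counts = np.zeros((n_states, n_actions), dtype=int)
    returns = []

    for _ in range(num_episodes):
        s = start_state
        G, discount = 0.0, 1.0
        for _ in range(max_steps):
            if s in terminal:
                break  # terminal: episode ends, no action taken, no visit counted
            # Assumption: sample the action first, then the successor state.
            a = np.random.choice(n_actions, p=policy[s])
            s_next = np.random.choice(n_states, p=P[s, a])
            state_counts[s] += 1      # visit counted only when an action is taken
            sa_counts[s, a] += 1
            G += discount * R[s, a, s_next]  # reward at step t weighted by gamma**t
            discount *= gamma
            s = s_next
        returns.append(G)

    total = state_counts.sum()
    if total > 0:
        freq = [round(float(c) / total, 4) for c in state_counts]
    else:
        freq = [0.0] * n_states  # no visits at all: all-zero distribution
    avg_return = round(float(np.mean(returns)), 4) if returns else 0.0
    return freq, sa_counts.tolist(), avg_return
```

Checking for a terminal state at the top of the loop (rather than after the transition) means the state the agent lands in is never counted, which matches the rule that only states where an action is actually taken contribute visits. For a two-state chain where state 0 always moves to terminal state 1 with reward 1, every episode yields a return of 1.0 and all visits land on state 0.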