Direct Preference Optimization (DPO) Loss
RLHF, Reward Modeling & Human Feedback DS practice problem on Onlearn.
Difficulty: medium.
Topics: Direct Preference Optimization (DPO) Loss, Bradley-Terry Model, Log-Sigmoid Likelihood, KL-Divergence Regularization, Implicit Reward Mapping, Policy-to-Reward Inversion, Statistical Learning Theory, Optimization Theory, Natural Language Processing, Reinforcement Learning, Information Theory, Preference Modeling, Policy Gradient Methods, Supervised Fine-Tuning, Objective Function Design, Probability Distribution Estimation.
Implement the Direct Preference Optimization (DPO) loss function used to align large language models with human preferences. DPO optimizes a language model policy directly on preference pairs (a chosen response and a rejected response) without training a separate reward model. It uses the log probabilities of each response under both the current policy model and a frozen reference model.

Your function should take:
- log_probs_chosen_policy: log probabilities of the preferred (chosen) responses under the policy model
- log_probs_rejected_policy: log probabilities of the dispreferred (rejected) responses under the policy model
- log_probs_chosen_ref: log probabilities of the preferred responses under the reference model
- log_probs_rejected_ref: log probabilities of the dispreferred responses under the reference model
- beta: a temperature parameter controlling the strength of the KL constraint

Your function should return a dictionary with:
- 'loss': the average DPO loss across the batch, rounded to 4 decimal places
- 'chosen_rewards': a list of implicit reward values for the chosen responses, each rounded to 4 decimal places
- 'rejected_rewards': a list of implicit reward values for the rejected responses, each rounded to 4 decimal places

The implicit reward for a response is the log ratio between its policy and reference probabilities, scaled by beta: reward = beta * (log_prob_policy - log_prob_ref). The DPO loss encourages the policy to assign a higher implicit reward to the chosen response than to the rejected one in each pair; the per-pair loss is -log(sigmoid(chosen_reward - rejected_reward)), which shrinks as that margin grows. All inputs are lists of floats of the same length (the batch size).
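A minimal pure-Python sketch of the computation described above, assuming the underscore-style parameter and key names (log_probs_chosen_policy, 'chosen_rewards', and so on) and rejecting empty or mismatched batches with a ValueError. It computes each implicit reward as beta times the policy-minus-reference log-probability difference, then averages -log(sigmoid(margin)) over the batch. The helper log_sigmoid is a hypothetical name introduced here; it uses a numerically stable form so large negative margins do not overflow math.exp.

```python
import math
from typing import Dict, List, Union


def log_sigmoid(x: float) -> float:
    # Numerically stable log(sigmoid(x)) (hypothetical helper):
    # x >= 0: -log(1 + exp(-x)); x < 0: x - log(1 + exp(x))
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def dpo_loss(
    log_probs_chosen_policy: List[float],
    log_probs_rejected_policy: List[float],
    log_probs_chosen_ref: List[float],
    log_probs_rejected_ref: List[float],
    beta: float,
) -> Dict[str, Union[float, List[float]]]:
    n = len(log_probs_chosen_policy)
    # Assumption: all four lists must be non-empty and share the batch size.
    if n == 0 or not (
        len(log_probs_rejected_policy) == len(log_probs_chosen_ref)
        == len(log_probs_rejected_ref) == n
    ):
        raise ValueError("all inputs must be non-empty lists of equal length")

    # Implicit reward: beta * (log pi_policy(y|x) - log pi_ref(y|x))
    chosen_rewards = [
        beta * (p - r) for p, r in zip(log_probs_chosen_policy, log_probs_chosen_ref)
    ]
    rejected_rewards = [
        beta * (p - r) for p, r in zip(log_probs_rejected_policy, log_probs_rejected_ref)
    ]

    # Per-pair loss: -log(sigmoid(chosen_reward - rejected_reward))
    losses = [-log_sigmoid(c - r) for c, r in zip(chosen_rewards, rejected_rewards)]

    return {
        "loss": round(sum(losses) / n, 4),
        "chosen_rewards": [round(x, 4) for x in chosen_rewards],
        "rejected_rewards": [round(x, 4) for x in rejected_rewards],
    }


if __name__ == "__main__":
    # Chosen reward 0.5, rejected reward -0.5, margin 1.0 -> loss = log(1 + e^-1)
    print(dpo_loss([-1.0], [-3.0], [-2.0], [-2.0], beta=0.5))
    # {'loss': 0.3133, 'chosen_rewards': [0.5], 'rejected_rewards': [-0.5]}
```

As a sanity check, when the policy equals the reference model every implicit reward is 0 and the loss is log 2 ≈ 0.6931; pushing the chosen response's log probability up relative to the reference (or the rejected one down) widens the margin and drives the loss toward 0.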