Skip to content

Commit

Permalink
🧮 Fix the computation of KL divergence loss (#2277)
Browse files Browse the repository at this point in the history
  • Loading branch information
d-tiapkin authored Oct 25, 2024
1 parent 110d088 commit ea7a1be
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions trl/trainer/nash_md_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,20 +297,19 @@ def _compute_losses(
ref_logprobs_model_data,
probability,
):
# Compute log probs
model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)

# reinforce score where 0.5 is a control variate
score = (probability - 0.5) * model_logprobs_model_data_sum
score = (probability - 0.5) * model_logprobs_model_data.sum(1)

# kl divergence
kl_div = model_logprobs_model_data_sum - ref_logprobs_model_data_sum
# kl divergence via reinforce
with torch.no_grad():
log_ratio = model_logprobs_model_data - ref_logprobs_model_data
kl_div_log = log_ratio.sum(1)
kl_div_loss = (log_ratio * model_logprobs_model_data).sum(1)

# final loss
loss = self.beta * kl_div - score
loss = self.beta * kl_div_loss - score

return loss.mean(), score, kl_div
return loss.mean(), score, kl_div_log

def _log_statistics(
self,
Expand Down

0 comments on commit ea7a1be

Please sign in to comment.