Commit

fix: silly mistake (#82)
* fix: a very stupid bug

* chore: refac
lxuechen authored Dec 3, 2023
1 parent 6f53082 commit f8b5fb1
Showing 1 changed file with 5 additions and 4 deletions.
src/alpaca_farm/rl/dpo_trainer.py (9 changes: 5 additions & 4 deletions)
@@ -38,14 +38,15 @@ def compute_loss(self, model, inputs, return_outputs=False):
         with torch.no_grad():
             ref_logits_w = self.ref_model(input_ids=input_ids_w, attention_mask=attention_mask_w).logits[..., :-1, :]
             ref_logits_l = self.ref_model(input_ids=input_ids_l, attention_mask=attention_mask_l).logits[..., :-1, :]
-            ref_logprobs_w = F.cross_entropy(ref_logits_w.transpose(-1, -2), labels_w, reduction="none").sum(-1)
-            ref_logprobs_l = F.cross_entropy(ref_logits_l.transpose(-1, -2), labels_l, reduction="none").sum(-1)
+            ref_logprobs_w = -F.cross_entropy(ref_logits_w.transpose(-1, -2), labels_w, reduction="none").sum(-1)
+            ref_logprobs_l = -F.cross_entropy(ref_logits_l.transpose(-1, -2), labels_l, reduction="none").sum(-1)

         logits_w = model(input_ids=input_ids_w, attention_mask=attention_mask_w).logits[..., :-1, :]
         logits_l = model(input_ids=input_ids_l, attention_mask=attention_mask_l).logits[..., :-1, :]
-        logprobs_w = F.cross_entropy(logits_w.transpose(-1, -2), labels_w, reduction="none").sum(-1)
-        logprobs_l = F.cross_entropy(logits_l.transpose(-1, -2), labels_l, reduction="none").sum(-1)
+        logprobs_w = -F.cross_entropy(logits_w.transpose(-1, -2), labels_w, reduction="none").sum(-1)
+        logprobs_l = -F.cross_entropy(logits_l.transpose(-1, -2), labels_l, reduction="none").sum(-1)

         logits = self.args.beta * ((logprobs_w - ref_logprobs_w) - (logprobs_l - ref_logprobs_l))
         loss = -F.logsigmoid(logits).mean(0)

         return (loss, dict(logits=logits)) if return_outputs else loss
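For context (not part of the commit): torch.nn.functional.cross_entropy returns the negative log-likelihood of each target token, so the per-sequence log-probability is the negated sum. Without the minus sign, logprobs_w and logprobs_l carry the wrong sign, which flips the DPO logits and effectively optimizes the preference in the wrong direction. A minimal standalone sketch of the sign convention, with toy shapes not taken from the repository:

    # Standalone sketch (toy tensors, not from dpo_trainer.py) of the sign convention.
    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    logits = torch.randn(2, 5, 11)          # toy (batch, seq_len, vocab) tensor
    labels = torch.randint(0, 11, (2, 5))   # toy target token ids

    # cross_entropy returns -log p(label) per token, so negating and summing
    # over the sequence dimension yields the sequence log-probability.
    logprobs = -F.cross_entropy(logits.transpose(-1, -2), labels, reduction="none").sum(-1)

    # The same quantity computed explicitly via log_softmax + gather.
    per_token = torch.log_softmax(logits, dim=-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    assert torch.allclose(logprobs, per_token.sum(-1))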

