diff --git a/src/alpaca_farm/rl/dpo_trainer.py b/src/alpaca_farm/rl/dpo_trainer.py
index eb18d85..ceee239 100644
--- a/src/alpaca_farm/rl/dpo_trainer.py
+++ b/src/alpaca_farm/rl/dpo_trainer.py
@@ -38,14 +38,15 @@ def compute_loss(self, model, inputs, return_outputs=False):
     with torch.no_grad():
         ref_logits_w = self.ref_model(input_ids=input_ids_w, attention_mask=attention_mask_w).logits[..., :-1, :]
         ref_logits_l = self.ref_model(input_ids=input_ids_l, attention_mask=attention_mask_l).logits[..., :-1, :]
-        ref_logprobs_w = F.cross_entropy(ref_logits_w.transpose(-1, -2), labels_w, reduction="none").sum(-1)
-        ref_logprobs_l = F.cross_entropy(ref_logits_l.transpose(-1, -2), labels_l, reduction="none").sum(-1)
+        ref_logprobs_w = -F.cross_entropy(ref_logits_w.transpose(-1, -2), labels_w, reduction="none").sum(-1)
+        ref_logprobs_l = -F.cross_entropy(ref_logits_l.transpose(-1, -2), labels_l, reduction="none").sum(-1)
 
     logits_w = model(input_ids=input_ids_w, attention_mask=attention_mask_w).logits[..., :-1, :]
     logits_l = model(input_ids=input_ids_l, attention_mask=attention_mask_l).logits[..., :-1, :]
-    logprobs_w = F.cross_entropy(logits_w.transpose(-1, -2), labels_w, reduction="none").sum(-1)
-    logprobs_l = F.cross_entropy(logits_l.transpose(-1, -2), labels_l, reduction="none").sum(-1)
+    logprobs_w = -F.cross_entropy(logits_w.transpose(-1, -2), labels_w, reduction="none").sum(-1)
+    logprobs_l = -F.cross_entropy(logits_l.transpose(-1, -2), labels_l, reduction="none").sum(-1)
 
     logits = self.args.beta * ((logprobs_w - ref_logprobs_w) - (logprobs_l - ref_logprobs_l))
     loss = -F.logsigmoid(logits).mean(0)
 
+    return (loss, dict(logits=logits)) if return_outputs else loss
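
Note on the sign fix: F.cross_entropy returns the per-token *negative* log-likelihood, so before this patch logprobs_w/logprobs_l (and their reference counterparts) were summed NLLs rather than sequence log-probabilities, which flipped the sign of the implicit reward beta * (logprobs - ref_logprobs) in the DPO objective. With the leading minus, logits is the scaled difference of policy-vs-reference log-ratios between the chosen (w) and rejected (l) responses, as DPO prescribes. A minimal standalone sketch of the identity (hypothetical shapes, not code from this repo):

import torch
import torch.nn.functional as F

# Hypothetical shapes: batch=2, seq_len=5, vocab=11.
logits = torch.randn(2, 5, 11)
labels = torch.randint(0, 11, (2, 5))

# F.cross_entropy expects the class dimension second, hence the transpose;
# with reduction="none" it returns per-token negative log-likelihoods, so
# the sequence log-probability needs the leading minus before the sum.
logprobs = -F.cross_entropy(logits.transpose(-1, -2), labels, reduction="none").sum(-1)

# Same quantity computed directly via log_softmax + gather, for comparison.
expected = torch.log_softmax(logits, dim=-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1).sum(-1)
assert torch.allclose(logprobs, expected, atol=1e-5)

(The sketch omits label padding; in the trainer itself, padded positions would additionally need to be masked out, e.g. via ignore_index, before summing.)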