Skip to content

Commit

Permalink
Merge pull request #9 from eltociear/patch-1
Browse files Browse the repository at this point in the history
chore: update trainer.py
  • Loading branch information
angelahzyuan authored Jul 4, 2024
2 parents 08c67a5 + b4a2429 commit e524519
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion sppo/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,7 +1093,7 @@ def compute_loss(
def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
"""Generate samples from the model and reference model for the given batch of inputs."""

# If one uses `generate_during_eval` with peft + bf16, we need to explictly call generate with
# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
# the torch cuda amp context manager as some hidden states are silently casted to full precision.
generate_context_manager = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast

Expand Down

0 comments on commit e524519

Please sign in to comment.