Skip to content

Commit

Permalink
⏰ Add start_time to _maybe_log_save_evaluate (#2373)
Browse files Browse the repository at this point in the history
  • Loading branch information
qgallouedec authored Nov 20, 2024
1 parent 5626806 commit a0066f4
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions trl/trainer/online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import transformers
from datasets import Dataset
from packaging import version
from torch.utils.data import DataLoader, IterableDataset
Expand Down Expand Up @@ -587,8 +588,9 @@ def training_step(

return loss.detach() / self.args.gradient_accumulation_steps

# Same as Trainer.evaluate but log our metrics
def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):
# Same as Trainer._maybe_log_save_evaluate but log our metrics
# start_time defaults to None to allow compatibility with transformers<=4.46
def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time=None):
if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
logs: Dict[str, float] = {}

Expand All @@ -612,7 +614,10 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, igno
self._globalstep_last_logged = self.state.global_step
self.store_flos()

self.log(logs)
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
self.log(logs, start_time)
else: # transformers<=4.46
self.log(logs)

metrics = None
if self.control.should_evaluate:
Expand Down

0 comments on commit a0066f4

Please sign in to comment.