diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py
index 8c03c5854e..2bd0e28ab2 100644
--- a/trl/trainer/online_dpo_trainer.py
+++ b/trl/trainer/online_dpo_trainer.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import inspect
 import os
 import textwrap
 import warnings
@@ -588,7 +589,8 @@ def training_step(
         return loss.detach() / self.args.gradient_accumulation_steps

     # Same as Trainer._maybe_log_save_evaluate but log our metrics
-    def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time):
+    # start_time defaults to None to allow compatibility with transformers<=4.46
+    def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time=None):
         if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
             logs: Dict[str, float] = {}

@@ -612,7 +614,10 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, igno
             self._globalstep_last_logged = self.state.global_step
             self.store_flos()

-            self.log(logs, start_time)
+            if "start_time" in inspect.signature(self.log).parameters:  # transformers>=4.47
+                self.log(logs, start_time)
+            else:  # transformers<=4.46
+                self.log(logs)

             metrics = None
             if self.control.should_evaluate:
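
Note on the pattern used above: rather than pinning or string-comparing the installed transformers version, the patch feature-detects at call time whether the bound Trainer.log method declares a start_time parameter (it does from transformers 4.47 on). Below is a minimal, self-contained sketch of the same idiom; LegacyLogger, NewLogger, and compat_log are hypothetical stand-ins for illustration, not part of the patch.

    import inspect


    class LegacyLogger:
        # Stand-in for transformers<=4.46, where Trainer.log takes only `logs`.
        def log(self, logs):
            print("legacy:", logs)


    class NewLogger:
        # Stand-in for transformers>=4.47, where Trainer.log also accepts `start_time`.
        def log(self, logs, start_time=None):
            print("new:", logs, start_time)


    def compat_log(trainer, logs, start_time):
        # Feature-detect instead of version-compare: pass `start_time` only if
        # the bound method actually declares that parameter. For bound methods,
        # inspect.signature() already excludes `self`.
        if "start_time" in inspect.signature(trainer.log).parameters:
            trainer.log(logs, start_time)
        else:
            trainer.log(logs)


    compat_log(LegacyLogger(), {"loss": 0.1}, 12.3)  # -> legacy: {'loss': 0.1}
    compat_log(NewLogger(), {"loss": 0.1}, 12.3)     # -> new: {'loss': 0.1} 12.3

A design note: signature introspection keeps the shim valid for any release that retains the parameter, whereas comparing version strings would need revisiting on every transformers release.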