Add start_time parameter to _maybe_log_save_evaluate method
qgallouedec committed Nov 20, 2024
1 parent 7651fc5 commit 2040a00
Showing 1 changed file with 7 additions and 2 deletions.
trl/trainer/online_dpo_trainer.py (7 additions, 2 deletions)
```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import os
 import textwrap
 import warnings
```
```diff
@@ -588,7 +589,8 @@ def training_step(
         return loss.detach() / self.args.gradient_accumulation_steps
 
     # Same as Trainer._maybe_log_save_evaluate but log our metrics
-    def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time):
+    # start_time defaults to None to allow compatibility with transformers<=4.46
+    def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time=None):
         if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
             logs: Dict[str, float] = {}
```
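The new default is what keeps this override working on transformers<=4.46, where the base Trainer's training loop still calls the hook without a start_time argument; a required parameter would raise TypeError there. A minimal sketch of the two call styles, using hypothetical stand-in functions rather than the real Trainer methods:

```python
# Hypothetical stand-ins, not the real Trainer API: they only mirror the
# required vs. defaulted shapes of the hook's signature.

def hook_required(tr_loss, start_time):
    """Hook with a mandatory start_time parameter."""

def hook_optional(tr_loss, start_time=None):
    """Hook that still accepts old-style calls."""

hook_optional(0.5)                   # transformers<=4.46 call style: OK
hook_optional(0.5, start_time=1.0)   # transformers>=4.47 call style: OK

try:
    hook_required(0.5)               # old call style against a required param
except TypeError as exc:
    print(exc)  # ... missing 1 required positional argument: 'start_time'
```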

```diff
@@ -612,7 +614,10 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, igno
             self._globalstep_last_logged = self.state.global_step
             self.store_flos()
 
-            self.log(logs, start_time)
+            if "start_time" in inspect.signature(self.log).parameters:  # transformers>=4.47
+                self.log(logs, start_time)
+            else:  # transformers<=4.46
+                self.log(logs)
 
         metrics = None
         if self.control.should_evaluate:
```
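The guard added here needs nothing beyond the standard library: inspect.signature exposes a callable's parameters, so the override can probe at runtime whether the installed Trainer.log accepts start_time (added in transformers 4.47) instead of pinning a version. A self-contained sketch of the same probe, using hypothetical stand-ins for the two Trainer.log shapes:

```python
import inspect

# Hypothetical stand-ins for the two shapes of Trainer.log; only their
# signatures matter for the feature probe.

def log_446(logs):                    # transformers<=4.46 shape
    print("logged:", logs)

def log_447(logs, start_time=None):   # transformers>=4.47 shape
    print("logged:", logs, "start_time:", start_time)

for log in (log_446, log_447):
    # Same runtime check as the diff: pass start_time only if accepted.
    if "start_time" in inspect.signature(log).parameters:
        log({"loss": 0.1}, start_time=123.4)
    else:
        log({"loss": 0.1})
```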
