Skip to content

Commit

Permalink
Import transformers package and update log method for compatibility w…
Browse files Browse the repository at this point in the history
…ith transformers>=4.47
  • Loading branch information
qgallouedec committed Nov 20, 2024
1 parent 2040a00 commit 5b88750
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions trl/trainer/online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import os
import textwrap
import warnings
Expand All @@ -26,6 +25,7 @@
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import transformers
from datasets import Dataset
from packaging import version
from torch.utils.data import DataLoader, IterableDataset
Expand Down Expand Up @@ -614,7 +614,7 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, igno
self._globalstep_last_logged = self.state.global_step
self.store_flos()

if "start_time" in inspect.signature(self.log).parameters: # transformers>=4.47
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
self.log(logs, start_time)
else: # transformers<=4.46
self.log(logs)
Expand Down

0 comments on commit 5b88750

Please sign in to comment.