From 7e1eff7600085814eac65876d4d8a0e38c2f6ccc Mon Sep 17 00:00:00 2001 From: Plemeur <37846989+Plemeur@users.noreply.github.com> Date: Wed, 8 Nov 2023 06:40:00 +0900 Subject: [PATCH] Allow scheduler parameters (#26480) * Allow for scheduler kwargs * Formatting * Arguments checks, passing the tests * Black failed somehow --------- Co-authored-by: Pierre --- src/transformers/optimization.py | 14 +++++++++++++- src/transformers/trainer.py | 1 + src/transformers/training_args.py | 10 ++++++++++ 3 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/transformers/optimization.py b/src/transformers/optimization.py index 5734b6e9cd5846..124813b22abbc5 100644 --- a/src/transformers/optimization.py +++ b/src/transformers/optimization.py @@ -337,6 +337,7 @@ def get_scheduler( optimizer: Optimizer, num_warmup_steps: Optional[int] = None, num_training_steps: Optional[int] = None, + scheduler_specific_kwargs: Optional[dict] = None, ): """ Unified API to get any scheduler from its name. @@ -352,6 +353,9 @@ def get_scheduler( num_training_steps (`int``, *optional*): The number of training steps to do. This is not required by all schedulers (hence the argument being optional), the function will raise an error if it's unset and the scheduler type requires it. + scheduler_specific_kwargs (`dict`, *optional*): + Extra parameters for schedulers such as cosine with restarts. Mismatched scheduler types and scheduler + parameters will cause the scheduler function to raise a TypeError. """ name = SchedulerType(name) schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] @@ -372,7 +376,15 @@ def get_scheduler( if num_training_steps is None: raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps) + if scheduler_specific_kwargs is None: + scheduler_specific_kwargs = {} + + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + **scheduler_specific_kwargs, + ) class AdamW(Optimizer): diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 7af45324926970..40159d816348c9 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1137,6 +1137,7 @@ def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optim optimizer=self.optimizer if optimizer is None else optimizer, num_warmup_steps=self.args.get_warmup_steps(num_training_steps), num_training_steps=num_training_steps, + **self.args.lr_scheduler_kwargs, ) self._created_lr_scheduler = True return self.lr_scheduler diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index aaedc83528a9c0..7ec6e56ff50551 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -238,6 +238,8 @@ class TrainingArguments: when all data is exhausted lr_scheduler_type (`str` or [`SchedulerType`], *optional*, defaults to `"linear"`): The scheduler type to use. See the documentation of [`SchedulerType`] for all possible values. + lr_scheduler_kwargs ('dict', *optional*, defaults to {}): + The extra arguments for the lr_scheduler. See the documentation of each scheduler for possible values. warmup_ratio (`float`, *optional*, defaults to 0.0): Ratio of total training steps used for a linear warmup from 0 to `learning_rate`. warmup_steps (`int`, *optional*, defaults to 0): @@ -729,6 +731,14 @@ class TrainingArguments: default="linear", metadata={"help": "The scheduler type to use."}, ) + lr_scheduler_kwargs: Optional[Dict] = field( + default_factory=dict, + metadata={ + "help": ( + "Extra parameters for the lr_scheduler such as {'num_cycles': 1} for the cosine with hard restarts" + ) + }, + ) warmup_ratio: float = field( default=0.0, metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."} )