Allow scheduler parameters (huggingface#26480)
* Allow for scheduler kwargs

* Formatting

* Arguments checks, passing the tests

* Black failed somehow

---------

Co-authored-by: Pierre <[email protected]>
Plemeur and Pierre authored Nov 7, 2023
1 parent ac5d4cf commit 7e1eff7
Showing 3 changed files with 24 additions and 1 deletion.
14 changes: 13 additions & 1 deletion src/transformers/optimization.py
@@ -337,6 +337,7 @@ def get_scheduler(
optimizer: Optimizer,
num_warmup_steps: Optional[int] = None,
num_training_steps: Optional[int] = None,
+    scheduler_specific_kwargs: Optional[dict] = None,
):
"""
Unified API to get any scheduler from its name.
@@ -352,6 +353,9 @@
num_training_steps (`int``, *optional*):
The number of training steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
+        scheduler_specific_kwargs (`dict`, *optional*):
+            Extra parameters for schedulers such as cosine with restarts. Mismatched scheduler types and scheduler
+            parameters will cause the scheduler function to raise a TypeError.
"""
name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
@@ -372,7 +376,15 @@
if num_training_steps is None:
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

-    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
+    if scheduler_specific_kwargs is None:
+        scheduler_specific_kwargs = {}
+
+    return schedule_func(
+        optimizer,
+        num_warmup_steps=num_warmup_steps,
+        num_training_steps=num_training_steps,
+        **scheduler_specific_kwargs,
+    )


class AdamW(Optimizer):
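For illustration only (not part of the commit), here is a minimal sketch of calling get_scheduler with the new scheduler_specific_kwargs argument; the toy model, optimizer, and step counts are made-up placeholders.

import torch
from transformers import get_scheduler

model = torch.nn.Linear(10, 2)  # toy model standing in for a real Transformer
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# "cosine_with_restarts" resolves to get_cosine_with_hard_restarts_schedule_with_warmup,
# and the extra dict is forwarded to it via **scheduler_specific_kwargs.
lr_scheduler = get_scheduler(
    name="cosine_with_restarts",
    optimizer=optimizer,
    num_warmup_steps=100,
    num_training_steps=1000,
    scheduler_specific_kwargs={"num_cycles": 3},
)

for _ in range(1000):
    optimizer.step()
    lr_scheduler.step()

Passing a key that the underlying schedule function does not accept raises a TypeError, as noted in the new docstring.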
1 change: 1 addition & 0 deletions src/transformers/trainer.py
@@ -1137,6 +1137,7 @@ def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optim
optimizer=self.optimizer if optimizer is None else optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
+                **self.args.lr_scheduler_kwargs,
)
self._created_lr_scheduler = True
return self.lr_scheduler
10 changes: 10 additions & 0 deletions src/transformers/training_args.py
@@ -238,6 +238,8 @@ class TrainingArguments:
when all data is exhausted
lr_scheduler_type (`str` or [`SchedulerType`], *optional*, defaults to `"linear"`):
The scheduler type to use. See the documentation of [`SchedulerType`] for all possible values.
+        lr_scheduler_kwargs ('dict', *optional*, defaults to {}):
+            The extra arguments for the lr_scheduler. See the documentation of each scheduler for possible values.
warmup_ratio (`float`, *optional*, defaults to 0.0):
Ratio of total training steps used for a linear warmup from 0 to `learning_rate`.
warmup_steps (`int`, *optional*, defaults to 0):
@@ -729,6 +731,14 @@ class TrainingArguments:
default="linear",
metadata={"help": "The scheduler type to use."},
)
+    lr_scheduler_kwargs: Optional[Dict] = field(
+        default_factory=dict,
+        metadata={
+            "help": (
+                "Extra parameters for the lr_scheduler such as {'num_cycles': 1} for the cosine with hard restarts"
+            )
+        },
+    )
warmup_ratio: float = field(
default=0.0, metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."}
)
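For illustration only (not part of the commit), a minimal sketch of setting the new lr_scheduler_kwargs field on TrainingArguments, using the {'num_cycles': ...} example from the field's help string; output_dir and the other values are placeholders. Trainer.create_scheduler forwards this dict to get_scheduler via **self.args.lr_scheduler_kwargs, as shown in the trainer.py hunk above.

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",                         # placeholder output directory
    lr_scheduler_type="cosine_with_restarts",
    lr_scheduler_kwargs={"num_cycles": 2},    # extra parameters for the lr scheduler
    warmup_steps=100,
    num_train_epochs=3,
)
print(args.lr_scheduler_kwargs)  # {'num_cycles': 2}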
