-
-
Notifications
You must be signed in to change notification settings - Fork 120
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #195 from WenjieDu/lr_scheduler
Add learning-rate schedulers
- Loading branch information
Showing
18 changed files
with
1,053 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,11 +6,12 @@ | |
# Created by Wenjie Du <[email protected]> | ||
# License: GPL-v3 | ||
|
||
from typing import Iterable | ||
from typing import Iterable, Optional | ||
|
||
from torch.optim import Adadelta as torch_Adadelta | ||
|
||
from .base import Optimizer | ||
from .lr_scheduler.base import LRScheduler | ||
|
||
|
||
class Adadelta(Optimizer): | ||
|
@@ -39,8 +40,9 @@ def __init__( | |
rho: float = 0.9, | ||
eps: float = 1e-08, | ||
weight_decay: float = 0.01, | ||
lr_scheduler: Optional[LRScheduler] = None, | ||
): | ||
super().__init__(lr) | ||
super().__init__(lr, lr_scheduler) | ||
self.rho = rho | ||
self.eps = eps | ||
self.weight_decay = weight_decay | ||
|
@@ -61,3 +63,6 @@ def init_optimizer(self, params: Iterable) -> None: | |
eps=self.eps, | ||
weight_decay=self.weight_decay, | ||
) | ||
|
||
if self.lr_scheduler is not None: | ||
self.lr_scheduler.init_scheduler(self.torch_optimizer) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,11 +6,12 @@ | |
# Created by Wenjie Du <[email protected]> | ||
# License: GPL-v3 | ||
|
||
from typing import Iterable | ||
from typing import Iterable, Optional | ||
|
||
from torch.optim import Adagrad as torch_Adagrad | ||
|
||
from .base import Optimizer | ||
from .lr_scheduler.base import LRScheduler | ||
|
||
|
||
class Adagrad(Optimizer): | ||
|
@@ -43,8 +44,9 @@ def __init__( | |
weight_decay: float = 0.01, | ||
initial_accumulator_value: float = 0.01, # it is set as 0 in the torch implementation, but delta shouldn't be 0 | ||
eps: float = 1e-08, | ||
lr_scheduler: Optional[LRScheduler] = None, | ||
): | ||
super().__init__(lr) | ||
super().__init__(lr, lr_scheduler) | ||
self.lr_decay = lr_decay | ||
self.weight_decay = weight_decay | ||
self.initial_accumulator_value = initial_accumulator_value | ||
|
@@ -67,3 +69,6 @@ def init_optimizer(self, params: Iterable) -> None: | |
initial_accumulator_value=self.initial_accumulator_value, | ||
eps=self.eps, | ||
) | ||
|
||
if self.lr_scheduler is not None: | ||
self.lr_scheduler.init_scheduler(self.torch_optimizer) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,11 +6,12 @@ | |
# Created by Wenjie Du <[email protected]> | ||
# License: GPL-v3 | ||
|
||
from typing import Iterable, Tuple | ||
from typing import Iterable, Tuple, Optional | ||
|
||
from torch.optim import Adam as torch_Adam | ||
|
||
from .base import Optimizer | ||
from .lr_scheduler.base import LRScheduler | ||
|
||
|
||
class Adam(Optimizer): | ||
|
@@ -42,8 +43,9 @@ def __init__( | |
eps: float = 1e-08, | ||
weight_decay: float = 0, | ||
amsgrad: bool = False, | ||
lr_scheduler: Optional[LRScheduler] = None, | ||
): | ||
super().__init__(lr) | ||
super().__init__(lr, lr_scheduler) | ||
self.betas = betas | ||
self.eps = eps | ||
self.weight_decay = weight_decay | ||
|
@@ -66,3 +68,6 @@ def init_optimizer(self, params: Iterable) -> None: | |
weight_decay=self.weight_decay, | ||
amsgrad=self.amsgrad, | ||
) | ||
|
||
if self.lr_scheduler is not None: | ||
self.lr_scheduler.init_scheduler(self.torch_optimizer) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,11 +6,12 @@ | |
# Created by Wenjie Du <[email protected]> | ||
# License: GPL-v3 | ||
|
||
from typing import Iterable, Tuple | ||
from typing import Iterable, Tuple, Optional | ||
|
||
from torch.optim import AdamW as torch_AdamW | ||
|
||
from .base import Optimizer | ||
from .lr_scheduler.base import LRScheduler | ||
|
||
|
||
class AdamW(Optimizer): | ||
|
@@ -42,8 +43,9 @@ def __init__( | |
eps: float = 1e-08, | ||
weight_decay: float = 0.01, | ||
amsgrad: bool = False, | ||
lr_scheduler: Optional[LRScheduler] = None, | ||
): | ||
super().__init__(lr) | ||
super().__init__(lr, lr_scheduler) | ||
self.betas = betas | ||
self.eps = eps | ||
self.weight_decay = weight_decay | ||
|
@@ -66,3 +68,6 @@ def init_optimizer(self, params: Iterable) -> None: | |
weight_decay=self.weight_decay, | ||
amsgrad=self.amsgrad, | ||
) | ||
|
||
if self.lr_scheduler is not None: | ||
self.lr_scheduler.init_scheduler(self.torch_optimizer) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
""" | ||
Learning rate schedulers available in PyPOTS. Their functionality is the same as that of their PyTorch counterparts. | ||
The only difference, which is also the reason we implement them, is that you don't have to pass the corresponding | ||
optimizers into them when initializing them. Instead, you can pass them into pypots.optim.Optimizer | ||
after initialization, and their `init_scheduler()` method is called in pypots.optim.Optimizer.init_optimizer() to | ||
initialize the schedulers together with the optimizers. | ||
""" | ||
|
||
# Created by Wenjie Du <[email protected]> | ||
# License: GPL-v3 | ||
|
||
from .lambda_lrs import LambdaLR | ||
from .multiplicative_lrs import MultiplicativeLR | ||
from .step_lrs import StepLR | ||
from .multistep_lrs import MultiStepLR | ||
from .constant_lrs import ConstantLR | ||
from .exponential_lrs import ExponentialLR | ||
from .linear_lrs import LinearLR | ||
|
||
|
||
__all__ = [ | ||
"LambdaLR", | ||
"MultiplicativeLR", | ||
"StepLR", | ||
"MultiStepLR", | ||
"ConstantLR", | ||
"ExponentialLR", | ||
"LinearLR", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
""" | ||
The base class for learning rate schedulers. This class is adapted from PyTorch, | ||
please refer to torch.optim.lr_scheduler for more details. | ||
""" | ||
|
||
# Created by Wenjie Du <[email protected]> | ||
# License: GPL-v3 | ||
|
||
import weakref | ||
from abc import ABC, abstractmethod | ||
from functools import wraps | ||
|
||
from torch.optim import Optimizer | ||
|
||
from ...utils.logging import logger | ||
|
||
|
||
class LRScheduler(ABC):
    """Base class for PyPOTS learning rate schedulers.

    Unlike the PyTorch schedulers this class is adapted from, the optimizer is
    not passed to the constructor. It is attached later via ``init_scheduler()``.

    Parameters
    ----------
    last_epoch: int
        The index of last epoch. Default: -1.

    verbose: bool
        If ``True``, logs a message (via ``logger.info``) for each update.
        Default: ``False``.
    """

    def __init__(self, last_epoch=-1, verbose=False):
        self.last_epoch = last_epoch
        self.verbose = verbose
        # The attributes below are populated by init_scheduler() / step().
        self.optimizer = None
        self.base_lrs = None
        self._last_lr = None
        self._step_count = 0

    def init_scheduler(self, optimizer: Optimizer) -> None:
        """Initialize the scheduler. This method should be called in pypots.optim.Optimizer.init_optimizer()
        to initialize the scheduler together with the optimizer.

        Parameters
        ----------
        optimizer: torch.optim.Optimizer,
            The optimizer to be scheduled.

        Raises
        ------
        TypeError
            If ``optimizer`` is not an instance of ``torch.optim.Optimizer``.

        KeyError
            If ``last_epoch`` is not -1 (i.e. resuming) and a param group has no ``initial_lr`` entry.
        """

        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
        self.optimizer = optimizer

        # Initialize epoch and base learning rates
        if self.last_epoch == -1:
            # Fresh start: remember the current lr of every group as its initial lr.
            for group in optimizer.param_groups:
                group.setdefault("initial_lr", group["lr"])
        else:
            # Resuming: every group must already carry its initial lr.
            for i, group in enumerate(optimizer.param_groups):
                if "initial_lr" not in group:
                    raise KeyError(
                        "param 'initial_lr' is not specified "
                        "in param_groups[{}] when resuming an optimizer".format(i)
                    )
        self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups]

        # Following https://github.com/pytorch/pytorch/issues/20124
        # We would like to ensure that `lr_scheduler.step()` is called after
        # `optimizer.step()`
        def with_counter(method):
            if getattr(method, "_with_counter", False):
                # `optimizer.step()` has already been replaced, return.
                return method

            # Keep a weak reference to the optimizer instance to prevent
            # cyclic references.
            instance_ref = weakref.ref(method.__self__)
            # Get the unbound method for the same purpose.
            func = method.__func__
            cls = instance_ref().__class__
            del method

            @wraps(func)
            def wrapper(*args, **kwargs):
                instance = instance_ref()
                # Count optimizer steps so that step() below can detect a wrong call order.
                instance._step_count += 1
                wrapped = func.__get__(instance, cls)
                return wrapped(*args, **kwargs)

            # Note that the returned function here is no longer a bound method,
            # so attributes like `__func__` and `__self__` no longer exist.
            wrapper._with_counter = True
            return wrapper

        self.optimizer.step = with_counter(self.optimizer.step)
        self.optimizer._step_count = 0

    @abstractmethod
    def get_lr(self):
        """Compute learning rate.

        Subclasses must return one learning rate per optimizer param group;
        ``step()`` zips the returned values with ``optimizer.param_groups``.
        """
        # Compute learning rate using chainable form of the scheduler
        raise NotImplementedError

    def get_last_lr(self):
        """Return last computed learning rate by current scheduler.

        This is ``None`` until ``step()`` has been called at least once.
        """
        return self._last_lr

    @staticmethod
    def print_lr(is_verbose, group, lr):
        """Display the current learning rate.

        Parameters
        ----------
        is_verbose: bool
            Whether to log the message at all.

        group:
            The identifier of the param group (``step()`` passes its index).

        lr: float
            The learning rate that has just been set for the group.
        """
        if is_verbose:
            logger.info(f"Adjusting learning rate of group {group} to {lr:.4e}.")

    def step(self):
        """Step could be called after every batch update. This should be called in ``pypots.optim.Optimizer.step()``
        after ``pypots.optim.Optimizer.torch_optimizer.step()``.

        Advances ``last_epoch`` by one, recomputes the learning rates with ``get_lr()``,
        writes them into the optimizer's param groups and caches them for ``get_last_lr()``.
        """
        # Raise a warning if old pattern is detected
        # https://github.com/pytorch/pytorch/issues/20124
        if self._step_count == 1:
            if not hasattr(self.optimizer.step, "_with_counter"):
                logger.warning(
                    "Seems like `optimizer.step()` has been overridden after learning rate scheduler "
                    "initialization. Please, make sure to call `optimizer.step()` before "
                    "`lr_scheduler.step()`. See more details at "
                    "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate",
                )

            # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
            elif self.optimizer._step_count < 1:
                # NOTE: this used to be `logger.warning.warn(...)`, which raised
                # AttributeError instead of emitting the warning.
                logger.warning(
                    "Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
                    "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
                    "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this "
                    "will result in PyTorch skipping the first value of the learning rate schedule. "
                    "See more details at "
                    "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate",
                )
        self._step_count += 1

        class _enable_get_lr_call:
            """Context manager flagging that ``get_lr()`` is being called from within ``step()``."""

            def __init__(self, o):
                self.o = o

            def __enter__(self):
                self.o._get_lr_called_within_step = True
                return self

            def __exit__(self, exc_type, exc_value, traceback):
                self.o._get_lr_called_within_step = False

        with _enable_get_lr_call(self):
            self.last_epoch += 1
            values = self.get_lr()

            # Apply the freshly computed learning rates group by group.
            for i, (param_group, lr) in enumerate(zip(self.optimizer.param_groups, values)):
                param_group["lr"] = lr
                self.print_lr(self.verbose, i, lr)

        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
Oops, something went wrong.