diff --git a/merlin/models/torch/models/base.py b/merlin/models/torch/models/base.py
index de2879ebb4..e47c2df4d6 100644
--- a/merlin/models/torch/models/base.py
+++ b/merlin/models/torch/models/base.py
@@ -13,12 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import inspect
+import itertools
 import os
-from typing import Dict, List, Optional, Sequence, Union
+from typing import Dict, Iterator, List, Optional, Sequence, Type, Union
 
 import torch
+from packaging import version
 from pytorch_lightning import LightningDataModule, LightningModule
-from torch import nn
+from torch import nn, optim
 
 from merlin.dataloader.torch import Loader
 from merlin.io import Dataset
@@ -28,6 +31,16 @@
 from merlin.models.torch.utils import module_utils
 from merlin.models.utils.registry import camelcase_to_snakecase
 
+OptimizerType = Union[optim.Optimizer, Type[optim.Optimizer], str]
+
+LRScheduler = (
+    optim.lr_scheduler._LRScheduler
+    if version.parse(torch.__version__).major < 2
+    else optim.lr_scheduler.LRScheduler
+)
+
+LRSchedulerType = Union[LRScheduler, Type[LRScheduler]]
+
 
 class Model(LightningModule, Block):
     """
@@ -43,8 +56,11 @@ class Model(LightningModule, Block):
     schema: Schema, optional
         A Merlin schema. Default is None.
     optimizer: torch.optim.Optimizer, optional
-        A PyTorch optimizer from the PyTorch library (or any custom optimizer
+        A PyTorch optimizer instance or class from the PyTorch library (or any custom optimizer
         that follows the same API). Default is Adam optimizer.
+    scheduler: torch.optim.lr_scheduler.LRScheduler, optional
+        A PyTorch learning rate scheduler instance from the PyTorch library (or any custom scheduler
+        that follows the same API). Default is None, which means no LR decay.
 
     Example usage
     -------------
@@ -57,17 +73,53 @@ class Model(LightningModule, Block):
     ... trainer.fit(model, Loader(dataset, batch_size=16))
     """
 
-    def __init__(self, *blocks: nn.Module, optimizer=torch.optim.Adam, initialization="auto"):
+    def __init__(self, *blocks: nn.Module, initialization="auto"):
         super().__init__()
 
         # Copied from BlockContainer.__init__
         self.values = nn.ModuleList()
         for module in blocks:
             self.values.append(self.wrap_module(module))
-
-        self.optimizer = optimizer
         self.initialization = initialization
 
+    @property
+    @torch.jit.ignore
+    def optimizer(self):
+        return self._optimizer if hasattr(self, "_optimizer") else None
+
+    def configure_optimizers(
+        self,
+        optimizer: Optional[OptimizerType] = None,
+        scheduler: Optional[LRSchedulerType] = None,
+    ):
+        """Configures the optimizer and, optionally, the learning rate scheduler for the model."""
+        if optimizer is None:
+            optimizer = self._optimizer if hasattr(self, "_optimizer") else "adam"
+        self._optimizer = create_optimizer(self, optimizer)
+
+        if scheduler is None:
+            if hasattr(self, "_scheduler"):
+                scheduler = self._scheduler
+            else:
+                self._scheduler = None
+        if scheduler is not None:
+            self._scheduler = get_scheduler(self._optimizer, scheduler)
+
+        if not isinstance(self._optimizer, (list, tuple)):
+            opt = [self._optimizer]
+        else:
+            opt = self._optimizer
+
+        if self._scheduler is not None:
+            if not isinstance(self._scheduler, (list, tuple)):
+                sched = [self._scheduler]
+            else:
+                sched = self._scheduler
+
+            return opt, sched
+
+        return opt
+
     def initialize(self, data: Union[Dataset, Loader, Batch]):
         """Initializes the model based on a given data set."""
         return module_utils.initialize(self, data, dtype=self._dtype)
@@ -119,10 +171,6 @@ def _val_step(self, batch, batch_idx, type="val"):
 
         return loss_and_metrics
 
-    def configure_optimizers(self):
-        """Configures the optimizer for the model."""
-        return self.optimizer(self.parameters())
-
     def model_outputs(self) -> List[ModelOutput]:
         """Finds all instances of `ModelOutput` in the model."""
         return self.find(ModelOutput)
@@ -376,3 +424,106 @@ def compute_loss(
         results[metric_name] = metric(_predictions, _targets)
 
     return results
+
+
+def create_optimizer(module: nn.Module, opt: OptimizerType) -> optim.Optimizer:
+    """
+    Creates an optimizer given a PyTorch module and an optimizer type.
+
+    Parameters
+    ----------
+    module : torch.nn.Module
+        The PyTorch model.
+    opt : str, Type[torch.optim.Optimizer], or torch.optim.Optimizer
+        The optimizer type, either as a string, a class, or an existing
+        PyTorch optimizer object.
+
+    Returns
+    -------
+    torch.optim.Optimizer
+        A PyTorch optimizer.
+
+    Raises
+    ------
+    ValueError
+        If the provided string for `opt` does not correspond to a known optimizer type.
+    TypeError
+        If `opt` is neither a string, a subclass of torch.optim.Optimizer,
+        nor an instance of torch.optim.Optimizer.
+    """
+
+    # Extract the model parameters
+    params = module.parameters()
+
+    # If opt is a string, create a new optimizer of the given type
+    if isinstance(opt, str):
+        if opt.lower() == "sgd":
+            return optim.SGD(params, lr=0.01)
+        elif opt.lower() == "adam":
+            return optim.Adam(params, lr=0.001)
+        elif opt.lower() == "adagrad":
+            return optim.Adagrad(params, lr=0.01)
+        else:
+            raise ValueError(f"Unsupported optimizer type: {opt}")
+
+    # If opt is an optimizer class, create a new optimizer of the given type
+    elif isinstance(opt, type) and issubclass(opt, optim.Optimizer):
+        return opt(params, lr=0.01)
+
+    # If opt is an optimizer instance, reuse it or rebuild it for this module's parameters
+    elif isinstance(opt, optim.Optimizer):
+        # Flattens a list of lists (or other iterable)
+        def flatten(lis: Iterator[Iterator]) -> Iterator:
+            return list(itertools.chain.from_iterable(lis))
+
+        # Extract parameters from the optimizer's param_groups
+        params_opt = flatten([group["params"] for group in opt.param_groups])
+        params_module = list(module.parameters())
+
+        # Check if the parameters of the module and the optimizer are the same
+        if params_module == params_opt:
+            # If parameters are the same, return the existing optimizer
+            return opt
+        else:
+            # If parameters are not the same, create a new optimizer of the same type
+            opt_type = type(opt)
+            return opt_type(params_module, **opt.defaults)
+
+    raise TypeError(
+        "Expected opt to be a string, a class of torch.optim.Optimizer, "
+        f"or an instance of torch.optim.Optimizer, but got {type(opt)}"
+    )
+
+
+def get_scheduler(optimizer: optim.Optimizer, scheduler: LRSchedulerType) -> LRScheduler:
+    """
+    Get an instance of a learning rate scheduler.
+
+    Parameters
+    ----------
+    optimizer : torch.optim.Optimizer
+        The optimizer to which the scheduler should be applied.
+    scheduler : LRSchedulerType
+        The scheduler or scheduler class to use.
+        If an instance is provided and its optimizer differs from the given optimizer,
+        a new instance of the same scheduler type is created for the given optimizer.
+        If the optimizers are the same, the original scheduler is returned.
+        If a class is provided, an instance is created with the optimizer as the only argument.
+
+    Returns
+    -------
+    torch.optim.lr_scheduler._LRScheduler
+        The scheduler instance.
+    """
+    if isinstance(scheduler, LRScheduler):
+        if scheduler.optimizer != optimizer:
+            return type(scheduler)(optimizer)
+        else:
+            return scheduler
+    elif inspect.isclass(scheduler) and issubclass(scheduler, LRScheduler):
+        return scheduler(optimizer)
+
+    raise TypeError(
+        "scheduler must be a subclass or an instance of optim.lr_scheduler.LRScheduler, "
+        f"got: {scheduler}"
+    )
diff --git a/tests/unit/torch/models/test_base.py b/tests/unit/torch/models/test_base.py
index 99ab5fae61..366e922cd1 100644
--- a/tests/unit/torch/models/test_base.py
+++ b/tests/unit/torch/models/test_base.py
@@ -44,13 +44,17 @@ def test_init_default(self):
         model = mm.Model(mm.Block(), nn.Linear(10, 10))
         assert isinstance(model, mm.Model)
         assert len(model) == 2
-        assert model.optimizer is torch.optim.Adam
-        assert isinstance(model.configure_optimizers(), torch.optim.Adam)
+        assert isinstance(model.configure_optimizers()[0], torch.optim.Adam)
 
-    def test_init_optimizer(self):
-        optimizer = torch.optim.SGD
-        model = mm.Model(mm.Block(), mm.Block(), optimizer=optimizer)
-        assert model.optimizer is torch.optim.SGD
+    def test_init_optimizer_and_scheduler(self):
+        model = mm.Model(mm.MLPBlock([4, 4]))
+        model.initialize(mm.Batch(torch.rand(2, 2)))
+
+        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.99)
+        opt, sched = model.configure_optimizers(optimizer, scheduler)
+        assert opt == [optimizer]
+        assert sched == [scheduler]
 
     def test_pre_and_pre(self):
         inputs = torch.tensor([[1, 2], [3, 4]])
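
For reviewers, a minimal usage sketch of the API introduced by this patch (not part of the diff itself). It assumes merlin.models.torch is importable as mm, and it reuses MLPBlock and Batch exactly as in the unit test above; everything else follows the code in the diff.

# Hypothetical usage sketch, not part of the patch.
import torch
import merlin.models.torch as mm  # assumed import alias, as in the test module

model = mm.Model(mm.MLPBlock([4, 4]))
model.initialize(mm.Batch(torch.rand(2, 2)))

# Explicit optimizer instance plus an LR scheduler: both come back wrapped in
# lists, matching Lightning's (optimizers, schedulers) convention.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
opt, sched = model.configure_optimizers(optimizer, scheduler)
assert opt == [optimizer] and sched == [scheduler]

# With no arguments the optimizer falls back to the "adam" string, i.e.
# torch.optim.Adam over the model's parameters with lr=0.001, and no scheduler.
model2 = mm.Model(mm.MLPBlock([4, 4]))
model2.initialize(mm.Batch(torch.rand(2, 2)))
(default_opt,) = model2.configure_optimizers()
assert isinstance(default_opt, torch.optim.Adam)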