Making it possible to set optimizer and scheduler instances via PyTorch model properties for training #1187

Merged
merged 7 commits on Jul 11, 2023
171 changes: 161 additions & 10 deletions merlin/models/torch/models/base.py
@@ -13,12 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import inspect
import itertools
import os
from typing import Dict, List, Optional, Sequence, Union
from typing import Dict, Iterator, List, Optional, Sequence, Type, Union

import torch
from packaging import version
from pytorch_lightning import LightningDataModule, LightningModule
from torch import nn
from torch import nn, optim

from merlin.dataloader.torch import Loader
from merlin.io import Dataset
@@ -28,6 +31,16 @@
from merlin.models.torch.utils import module_utils
from merlin.models.utils.registry import camelcase_to_snakecase

OptimizerType = Union[optim.Optimizer, Type[optim.Optimizer], str]

LRScheduler = (
optim.lr_scheduler._LRScheduler
if version.parse(torch.__version__).major < 2
else optim.lr_scheduler.LRScheduler
)

LRSchedulerType = Union[LRScheduler, Type[LRScheduler]]


class Model(LightningModule, Block):
"""
@@ -43,8 +56,11 @@ class Model(LightningModule, Block):
schema: Schema, optional
A Merlin schema. Default is None.
optimizer: torch.optim.Optimizer, optional
A PyTorch optimizer from the PyTorch library (or any custom optimizer
A PyTorch optimizer instance or class from the PyTorch library (or any custom optimizer
that follows the same API). Default is Adam optimizer.
scheduler: torch.optim.lr_scheduler.LRScheduler, optional
A PyTorch learning rate scheduler instance or class from the PyTorch library (or any custom scheduler
that follows the same API). Default is None, which means no LR decay.

Example usage
-------------
@@ -57,17 +73,53 @@ class Model(LightningModule, Block):
... trainer.fit(model, Loader(dataset, batch_size=16))
"""

def __init__(self, *blocks: nn.Module, optimizer=torch.optim.Adam, initialization="auto"):
def __init__(self, *blocks: nn.Module, initialization="auto"):
super().__init__()

# Copied from BlockContainer.__init__
self.values = nn.ModuleList()
for module in blocks:
self.values.append(self.wrap_module(module))

self.optimizer = optimizer
self.initialization = initialization

@property
@torch.jit.ignore
def optimizer(self):
return self._optimizer if hasattr(self, "_optimizer") else None

def configure_optimizers(
self,
optimizer: Optional[OptimizerType] = None,
scheduler: Optional[LRSchedulerType] = None,
):
"""Configures the optimizer for the model."""
if optimizer is None:
optimizer = self._optimizer if hasattr(self, "_optimizer") else "adam"
self._optimizer = create_optimizer(self, optimizer)

if scheduler is None:
if hasattr(self, "_scheduler"):
scheduler = self._scheduler
else:
self._scheduler = None
if scheduler is not None:
self._scheduler = get_scheduler(self._optimizer, scheduler)

if not isinstance(self._optimizer, (list, tuple)):
opt = [self._optimizer]
else:
opt = self._optimizer

if self._scheduler is not None:
if not isinstance(self._scheduler, (list, tuple)):
sched = [self._scheduler]
else:
sched = self._scheduler

return opt, sched

return opt

def initialize(self, data: Union[Dataset, Loader, Batch]):
"""Initializes the model based on a given data set."""
return module_utils.initialize(self, data, dtype=self._dtype)
@@ -119,10 +171,6 @@ def _val_step(self, batch, batch_idx, type="val"):

return loss_and_metrics

def configure_optimizers(self):
"""Configures the optimizer for the model."""
return self.optimizer(self.parameters())

def model_outputs(self) -> List[ModelOutput]:
"""Finds all instances of `ModelOutput` in the model."""
return self.find(ModelOutput)
@@ -376,3 +424,106 @@ def compute_loss(

results[metric_name] = metric(_predictions, _targets)
return results


def create_optimizer(module: nn.Module, opt: OptimizerType) -> optim.Optimizer:
"""
Creates an optimizer given a PyTorch module and an optimizer type.

Parameters
----------
module : torch.nn.Module
The PyTorch model.
opt : str, Type[torch.optim.Optimizer], or torch.optim.Optimizer
The optimizer type, either as a string, a class, or an existing
PyTorch optimizer object.

Returns
-------
torch.optim.Optimizer
A PyTorch optimizer.

Raises
------
ValueError
If the provided string for opt does not correspond to a known optimizer type.
TypeError
If the type of opt is neither string, class of torch.optim.Optimizer,
nor instance of torch.optim.Optimizer.
"""

# Extract the model parameters
params = module.parameters()

# If opt is a string, create a new optimizer of the given type
if isinstance(opt, str):
if opt.lower() == "sgd":
return optim.SGD(params, lr=0.01)
elif opt.lower() == "adam":
return optim.Adam(params, lr=0.001)
elif opt.lower() == "adagrad":
return optim.Adagrad(params, lr=0.01)
else:
raise ValueError(f"Unsupported optimizer type: {opt}")

# If opt is an optimizer class, create a new optimizer of the given type
elif isinstance(opt, type) and issubclass(opt, optim.Optimizer):
return opt(params, lr=0.01)

# If opt is an optimizer instance, reuse it or rebuild it for this module's parameters
elif isinstance(opt, optim.Optimizer):
# Flattens a list of lists (or other iterable)
def flatten(lis: Iterator[Iterator]) -> Iterator:
return list(itertools.chain.from_iterable(lis))

# Extract parameters from optimizer's param_groups
params_opt = flatten([group["params"] for group in opt.param_groups])
params_module = list(module.parameters())

# Check if the parameters of the module and the optimizer are the same
if params_module == params_opt:
# If parameters are the same, return the existing optimizer
return opt
else:
# If parameters are not the same, create a new optimizer of the same type
opt_type = type(opt)
return opt_type(params_module, **opt.defaults)

raise TypeError(
"Expected opt to be a string, a class of torch.optim.Optimizer, ",
f"or an instance of torch.optim.Optimizer, but got {type(opt)}",
)


def get_scheduler(optimizer: optim.Optimizer, scheduler: LRSchedulerType) -> LRScheduler:
"""
Get an instance of a learning rate scheduler.

Parameters
----------
optimizer : torch.optim.Optimizer
The optimizer to which the scheduler should be applied.
scheduler : LRSchedulerType
The scheduler instance or scheduler class to use.
If an instance is provided whose optimizer differs from the given optimizer,
a new scheduler of the same type is created for the given optimizer.
If the optimizers match, the original scheduler is returned unchanged.
If a class is provided, it is instantiated with the optimizer as its only argument.

Returns
-------
torch.optim.lr_scheduler._LRScheduler
The scheduler instance.
"""
if isinstance(scheduler, LRScheduler):
if scheduler.optimizer != optimizer:
return type(scheduler)(optimizer)
else:
return scheduler
elif inspect.isclass(scheduler) and issubclass(scheduler, LRScheduler):
return scheduler(optimizer)

raise TypeError(
"scheduler must be a subclass or instance of optim.lr_scheduler.LRScheduler ",
f"got: {scheduler}",
)
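A minimal sketch, not part of the diff above, of how the new create_optimizer and get_scheduler helpers behave when called directly. The nn.Linear stand-in module and the ConstantLR scheduler are illustrative choices only; the import path merlin.models.torch.models.base follows the file changed in this PR.

# Sketch only: assumes merlin-models with this PR installed
from torch import nn, optim

from merlin.models.torch.models.base import create_optimizer, get_scheduler

module = nn.Linear(8, 1)  # illustrative stand-in for a real model

# Strings and optimizer classes build a fresh optimizer over module.parameters()
adam = create_optimizer(module, "adam")    # optim.Adam with lr=0.001
sgd = create_optimizer(module, optim.SGD)  # optim.SGD with lr=0.01

# An optimizer instance that already covers the module's parameters is reused as-is
assert create_optimizer(module, adam) is adam

# A scheduler class is instantiated with the optimizer as its only argument,
# so it must be constructible that way (e.g. ConstantLR or LinearLR)
sched = get_scheduler(adam, optim.lr_scheduler.ConstantLR)

# A scheduler instance bound to a different optimizer is rebuilt for the given one
assert get_scheduler(sgd, sched) is not sched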
16 changes: 10 additions & 6 deletions tests/unit/torch/models/test_base.py
@@ -44,13 +44,17 @@ def test_init_default(self):
model = mm.Model(mm.Block(), nn.Linear(10, 10))
assert isinstance(model, mm.Model)
assert len(model) == 2
assert model.optimizer is torch.optim.Adam
assert isinstance(model.configure_optimizers(), torch.optim.Adam)
assert isinstance(model.configure_optimizers()[0], torch.optim.Adam)

def test_init_optimizer(self):
optimizer = torch.optim.SGD
model = mm.Model(mm.Block(), mm.Block(), optimizer=optimizer)
assert model.optimizer is torch.optim.SGD
def test_init_optimizer_and_scheduler(self):
model = mm.Model(mm.MLPBlock([4, 4]))
model.initialize(mm.Batch(torch.rand(2, 2)))

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.99)
opt, sched = model.configure_optimizers(optimizer, scheduler)
assert opt == [optimizer]
assert sched == [scheduler]

def test_pre_and_pre(self):
inputs = torch.tensor([[1, 2], [3, 4]])
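A hedged end-to-end sketch of the behavior the new test exercises: the argument-less configure_optimizers() call that PyTorch Lightning makes during trainer.fit picks up the stored instances instead of the default Adam. The `import merlin.models.torch as mm` alias is an assumption based on the `mm` shorthand in the tests, and the model here has no prediction head, so it only illustrates the optimizer wiring.

# Sketch only: assumes the usual `merlin.models.torch as mm` import alias
import torch
import merlin.models.torch as mm

model = mm.Model(mm.MLPBlock([4, 4]))
model.initialize(mm.Batch(torch.rand(2, 2)))

# Optimizer and scheduler *instances* can now be handed to configure_optimizers
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.99)
opt, sched = model.configure_optimizers(optimizer, scheduler)
assert opt == [optimizer] and sched == [scheduler]

# The instances are stored on the model, so the argument-less call that
# Lightning makes during trainer.fit(...) returns the same objects
# instead of falling back to the default Adam optimizer.
opt_again, sched_again = model.configure_optimizers()
assert opt_again == [optimizer] and sched_again == [scheduler]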