From 8bc7b0115a7b4561956dda645d87d1f64c0ba40c Mon Sep 17 00:00:00 2001 From: LarsKue Date: Tue, 9 May 2023 20:09:34 +0200 Subject: [PATCH 01/53] hparam optimization with ray[tune] --- pyproject.toml | 4 +- .../trainable/trainable.py | 53 ++++++++++++++++++- 2 files changed, 55 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2004e9e..32ff79a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,9 @@ tests = [ ] experiments = [ - # these are just recommended packages to run experiments + # required + "ray[tune] ~= 2.4", + # recommended "numpy ~= 1.24", "matplotlib ~= 3.7", "jupyterlab ~= 3.6", diff --git a/src/lightning_trainable/trainable/trainable.py b/src/lightning_trainable/trainable/trainable.py index 45381e2..a16d752 100644 --- a/src/lightning_trainable/trainable/trainable.py +++ b/src/lightning_trainable/trainable/trainable.py @@ -376,7 +376,7 @@ def fit_fast(self, device="cuda"): batch = tuple(t.to(device) for t in batch if torch.is_tensor(t)) optimizer.zero_grad() - loss = self.training_step(batch, 0) + loss = self.training_step(batch, batch_idx) loss.backward() optimizer.step() @@ -388,6 +388,57 @@ def load_checkpoint(cls, root: str | pathlib.Path = "lightning_logs", version: i checkpoint = utils.find_checkpoint(root, version, epoch, step) return cls.load_from_checkpoint(checkpoint, **kwargs) + @classmethod + def optimize_hparams(cls, hparams: dict, scheduler: str = "asha", model_kwargs: dict = None, tune_kwargs: dict = None): + """ Optimize the HParams with ray[tune] """ + try: + from ray import tune + except ModuleNotFoundError: + raise ModuleNotFoundError(f"Please install lightning-trainable[experiments] to use `optimize_hparams`.") + + if model_kwargs is None: + model_kwargs = dict() + if tune_kwargs is None: + tune_kwargs = dict() + + def train(hparams): + model = cls(hparams, **model_kwargs) + model.fit( + logger_kwargs=dict(save_dir=tune.get_trial_dir()), + trainer_kwargs=dict(enable_progress_bar=False) + ) + + return model + + match scheduler: + case "asha": + scheduler = tune.schedulers.AsyncHyperBandScheduler( + time_attr="time_total_s", + max_t=600, + grace_period=180, + reduction_factor=4, + ) + case other: + raise NotImplementedError(f"Unrecognized scheduler: '{other}'") + + reporter = tune.JupyterNotebookReporter( + overwrite=True, + parameter_columns=list(hparams.keys()), + metric_columns=[f"training/{cls.hparams.loss}", f"validation/{cls.hparams.loss}"], + ) + + analysis = tune.run( + train, + metric=f"validation/{cls.hparams.loss}", + mode="min", + config=hparams, + scheduler=scheduler, + progress_reporter=reporter, + **tune_kwargs + ) + + return analysis + def auto_pin_memory(pin_memory: bool | None, accelerator: str): if pin_memory is None: From 5dd0aadb8583d14dcf926278c38d5efd665c02c7 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Tue, 9 May 2023 20:48:39 +0200 Subject: [PATCH 02/53] fix working directory --- src/lightning_trainable/trainable/trainable.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/lightning_trainable/trainable/trainable.py b/src/lightning_trainable/trainable/trainable.py index a16d752..18accd5 100644 --- a/src/lightning_trainable/trainable/trainable.py +++ b/src/lightning_trainable/trainable/trainable.py @@ -393,6 +393,7 @@ def optimize_hparams(cls, hparams: dict, scheduler: str = "asha", model_kwargs: """ Optimize the HParams with ray[tune] """ try: from ray import tune + import os except ModuleNotFoundError: raise ModuleNotFoundError(f"Please install 
lightning-trainable[experiments] to use `optimize_hparams`.") @@ -400,11 +401,14 @@ def optimize_hparams(cls, hparams: dict, scheduler: str = "asha", model_kwargs: model_kwargs = dict() if tune_kwargs is None: tune_kwargs = dict() + + path = os.getcwd() def train(hparams): + os.chdir(path) model = cls(hparams, **model_kwargs) model.fit( - logger_kwargs=dict(save_dir=tune.get_trial_dir()), + logger_kwargs=dict(save_dir=str(tune.get_trial_dir())), trainer_kwargs=dict(enable_progress_bar=False) ) @@ -424,12 +428,12 @@ def train(hparams): reporter = tune.JupyterNotebookReporter( overwrite=True, parameter_columns=list(hparams.keys()), - metric_columns=[f"training/{cls.hparams.loss}", f"validation/{cls.hparams.loss}"], + metric_columns=[f"training/{cls.hparams_type.loss}", f"validation/{cls.hparams_type.loss}"], ) analysis = tune.run( train, - metric=f"validation/{cls.hparams.loss}", + metric=f"validation/{cls.hparams_type.loss}", mode="min", config=hparams, scheduler=scheduler, From f8c82eb68e174fd21c127653f5e09b4ca937367b Mon Sep 17 00:00:00 2001 From: Lars Kuehmichel Date: Wed, 10 May 2023 15:36:10 +0200 Subject: [PATCH 03/53] Fix onecyclelr for accumulate_batches > 1 --- src/lightning_trainable/trainable/trainable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning_trainable/trainable/trainable.py b/src/lightning_trainable/trainable/trainable.py index 18accd5..0617cd0 100644 --- a/src/lightning_trainable/trainable/trainable.py +++ b/src/lightning_trainable/trainable/trainable.py @@ -104,7 +104,7 @@ def configure_lr_schedulers(self, optimizer): kwargs = dict( max_lr=optimizer.defaults["lr"], epochs=self.hparams.max_epochs, - steps_per_epoch=len(self.train_dataloader()) + steps_per_epoch=len(self.train_dataloader()) // self.hparams.accumulate_batches, ) interval = "step" case _: From 365164c9edec1a3b9c518ca6edb8630a1487bae1 Mon Sep 17 00:00:00 2001 From: Lars Kuehmichel Date: Thu, 11 May 2023 14:16:55 +0200 Subject: [PATCH 04/53] add error (1 - accuracy) --- src/lightning_trainable/metrics/__init__.py | 1 + src/lightning_trainable/metrics/error.py | 6 ++++++ 2 files changed, 7 insertions(+) create mode 100644 src/lightning_trainable/metrics/error.py diff --git a/src/lightning_trainable/metrics/__init__.py b/src/lightning_trainable/metrics/__init__.py index 06a1dab..9f303c3 100644 --- a/src/lightning_trainable/metrics/__init__.py +++ b/src/lightning_trainable/metrics/__init__.py @@ -1 +1,2 @@ from .accuracy import accuracy +from .error import error diff --git a/src/lightning_trainable/metrics/error.py b/src/lightning_trainable/metrics/error.py new file mode 100644 index 0000000..ec4b458 --- /dev/null +++ b/src/lightning_trainable/metrics/error.py @@ -0,0 +1,6 @@ + +from .accuracy import accuracy + + +def error(logits, targets, *, k=1): + return 1.0 - accuracy(logits, targets, k=k) From 1452cadee7c6ba73f00eb713b5495922cb4b2d0c Mon Sep 17 00:00:00 2001 From: Lars Kuehmichel Date: Thu, 11 May 2023 14:17:02 +0200 Subject: [PATCH 05/53] add deprecate --- src/lightning_trainable/utils/__init__.py | 1 + src/lightning_trainable/utils/deprecate.py | 6 ++++++ 2 files changed, 7 insertions(+) create mode 100644 src/lightning_trainable/utils/deprecate.py diff --git a/src/lightning_trainable/utils/__init__.py b/src/lightning_trainable/utils/__init__.py index daabbf4..2ffc560 100644 --- a/src/lightning_trainable/utils/__init__.py +++ b/src/lightning_trainable/utils/__init__.py @@ -1,3 +1,4 @@ +from .deprecate import deprecate from .io import find_checkpoint from .iteration 
import flatten, zip from .modules import ( diff --git a/src/lightning_trainable/utils/deprecate.py b/src/lightning_trainable/utils/deprecate.py new file mode 100644 index 0000000..bc58fa9 --- /dev/null +++ b/src/lightning_trainable/utils/deprecate.py @@ -0,0 +1,6 @@ + +import warnings + + +def deprecate(message: str): + warnings.warn(message, DeprecationWarning, stacklevel=2) From 3826c4503dc9ca8e6660fe714d7517228960c6f4 Mon Sep 17 00:00:00 2001 From: Lars Kuehmichel Date: Thu, 11 May 2023 14:17:41 +0200 Subject: [PATCH 06/53] rework configure_optimizers and configure_lr_schedulers --- .../trainable/lr_schedulers/__init__.py | 1 + .../trainable/lr_schedulers/configure.py | 40 ++++++++ .../trainable/lr_schedulers/defaults.py | 16 ++++ .../trainable/optimizers/__init__.py | 1 + .../trainable/optimizers/configure.py | 38 ++++++++ .../trainable/optimizers/defaults.py | 7 ++ .../trainable/trainable.py | 96 +++---------------- .../trainable/trainable_hparams.py | 34 +++++++ tests/test_trainable.py | 4 +- 9 files changed, 150 insertions(+), 87 deletions(-) create mode 100644 src/lightning_trainable/trainable/lr_schedulers/__init__.py create mode 100644 src/lightning_trainable/trainable/lr_schedulers/configure.py create mode 100644 src/lightning_trainable/trainable/lr_schedulers/defaults.py create mode 100644 src/lightning_trainable/trainable/optimizers/__init__.py create mode 100644 src/lightning_trainable/trainable/optimizers/configure.py create mode 100644 src/lightning_trainable/trainable/optimizers/defaults.py diff --git a/src/lightning_trainable/trainable/lr_schedulers/__init__.py b/src/lightning_trainable/trainable/lr_schedulers/__init__.py new file mode 100644 index 0000000..9234965 --- /dev/null +++ b/src/lightning_trainable/trainable/lr_schedulers/__init__.py @@ -0,0 +1 @@ +from .configure import configure diff --git a/src/lightning_trainable/trainable/lr_schedulers/configure.py b/src/lightning_trainable/trainable/lr_schedulers/configure.py new file mode 100644 index 0000000..89c0307 --- /dev/null +++ b/src/lightning_trainable/trainable/lr_schedulers/configure.py @@ -0,0 +1,40 @@ + +from torch.optim.lr_scheduler import LRScheduler + +from lightning_trainable.utils import get_scheduler + +from . 
import defaults + + +def with_kwargs(model, optimizer, **kwargs) -> LRScheduler: + """ + Get a learning rate scheduler with the given kwargs + Insert default values for missing kwargs + @param model: Trainable model + @param optimizer: The optimizer to use with the scheduler + @param kwargs: Keyword arguments for the scheduler + @return: The configured learning rate scheduler + """ + name = kwargs.pop("name") + default_kwargs = defaults.get_defaults(name, model, optimizer) + kwargs = default_kwargs | kwargs + return get_scheduler(name)(optimizer, **kwargs) + + +def configure(model, optimizer) -> LRScheduler | None: + """ + Configure a learning rate scheduler from the model's hparams + @param model: Trainable model + @param optimizer: The optimizer to use with the scheduler + @return: The configured learning rate scheduler + """ + match model.hparams.lr_scheduler: + case str() as name: + return with_kwargs(model, optimizer, name=name) + case dict() as kwargs: + return with_kwargs(model, optimizer, **kwargs.copy()) + case None: + # do not use a learning rate scheduler + return None + case other: + raise NotImplementedError(f"Unrecognized Scheduler: '{other}'") diff --git a/src/lightning_trainable/trainable/lr_schedulers/defaults.py b/src/lightning_trainable/trainable/lr_schedulers/defaults.py new file mode 100644 index 0000000..18a672f --- /dev/null +++ b/src/lightning_trainable/trainable/lr_schedulers/defaults.py @@ -0,0 +1,16 @@ + +def get_defaults(scheduler_name, model, optimizer): + match scheduler_name: + case "OneCycleLR": + max_lr = optimizer.defaults["lr"] + total_steps = model.hparams.max_steps + if total_steps == -1: + total_steps = model.hparams.max_epochs * len(model.train_dataloader()) + total_steps = int(total_steps / model.hparams.accumulate_batches) + + return dict( + max_lr=max_lr, + total_steps=total_steps + ) + case _: + return dict() diff --git a/src/lightning_trainable/trainable/optimizers/__init__.py b/src/lightning_trainable/trainable/optimizers/__init__.py new file mode 100644 index 0000000..9234965 --- /dev/null +++ b/src/lightning_trainable/trainable/optimizers/__init__.py @@ -0,0 +1 @@ +from .configure import configure diff --git a/src/lightning_trainable/trainable/optimizers/configure.py b/src/lightning_trainable/trainable/optimizers/configure.py new file mode 100644 index 0000000..3828d13 --- /dev/null +++ b/src/lightning_trainable/trainable/optimizers/configure.py @@ -0,0 +1,38 @@ + +from torch.optim import Optimizer + +from lightning_trainable.utils import get_optimizer + +from . 
import defaults + + +def with_kwargs(model, **kwargs) -> Optimizer: + """ + Get an optimizer with the given kwargs + Insert default values for missing kwargs + @param model: Trainable model + @param kwargs: Keyword arguments for the optimizer + @return: The configured optimizer + """ + name = kwargs.pop("name") + default_kwargs = defaults.get_defaults(name, model) + kwargs = default_kwargs | kwargs + return get_optimizer(name)(**kwargs) + + +def configure(model) -> Optimizer | None: + """ + Configure an optimizer from the model's hparams + @param model: Trainable model + @return: The configured optimizer + """ + match model.hparams.optimizer: + case str() as name: + return with_kwargs(model, name=name) + case dict() as kwargs: + return with_kwargs(model, **kwargs.copy()) + case None: + # do not use an optimizer + return None + case other: + raise NotImplementedError(f"Unrecognized Optimizer: '{other}'") diff --git a/src/lightning_trainable/trainable/optimizers/defaults.py b/src/lightning_trainable/trainable/optimizers/defaults.py new file mode 100644 index 0000000..5b9b23c --- /dev/null +++ b/src/lightning_trainable/trainable/optimizers/defaults.py @@ -0,0 +1,7 @@ + +def get_defaults(optimizer_name, model): + match optimizer_name: + case _: + return dict( + params=model.parameters(), + ) diff --git a/src/lightning_trainable/trainable/trainable.py b/src/lightning_trainable/trainable/trainable.py index 0617cd0..d23e5ab 100644 --- a/src/lightning_trainable/trainable/trainable.py +++ b/src/lightning_trainable/trainable/trainable.py @@ -13,6 +13,9 @@ from lightning_trainable.callbacks import EpochProgressBar from .trainable_hparams import TrainableHParams +from . import lr_schedulers +from . import optimizers + class SkipBatch(Exception): pass @@ -87,95 +90,18 @@ def test_step(self, batch, batch_idx): self.log(f"test/{key}", value, prog_bar=key == self.hparams.loss) - def configure_lr_schedulers(self, optimizer): - """ - Configure the LR Scheduler as defined in HParams. - By default, we only use a single LR Scheduler, attached to a single optimizer. - You can use a ChainedScheduler if you need multiple LR Schedulers throughout training, - or override this method if you need different schedulers for different parameters. - - @param optimizer: The optimizer to attach the scheduler to. - @return: The LR Scheduler object. 
- """ - match self.hparams.lr_scheduler: - case str() as name: - match name.lower(): - case "onecyclelr": - kwargs = dict( - max_lr=optimizer.defaults["lr"], - epochs=self.hparams.max_epochs, - steps_per_epoch=len(self.train_dataloader()) // self.hparams.accumulate_batches, - ) - interval = "step" - case _: - kwargs = dict() - interval = "step" - scheduler = utils.get_scheduler(name)(optimizer, **kwargs) - return dict( - scheduler=scheduler, - interval=interval, - ) - case dict() as kwargs: - # Copy the dict so we don't modify the original - kwargs = deepcopy(kwargs) - - name = kwargs.pop("name") - interval = "step" - if "interval" in kwargs: - interval = kwargs.pop("interval") - scheduler = utils.get_scheduler(name)(optimizer, **kwargs) - return dict( - scheduler=scheduler, - interval=interval, - ) - case type(torch.optim.lr_scheduler.LRScheduler) as Scheduler: - kwargs = dict() - interval = "step" - scheduler = Scheduler(optimizer, **kwargs) - return dict( - scheduler=scheduler, - interval=interval, - ) - case (torch.optim.lr_scheduler.LRScheduler() | torch.optim.lr_scheduler.ReduceLROnPlateau()) as scheduler: - return dict( - scheduler=scheduler, - interval="step", - ) - case None: - # do not use a scheduler - return None - case other: - raise NotImplementedError(f"Unrecognized Scheduler: {other}") - def configure_optimizers(self): """ - Configure Optimizer and LR Scheduler objects as defined in HParams. - By default, we only use a single optimizer and an optional LR Scheduler. - If you need multiple optimizers, override this method. + Configure the optimizer and learning rate scheduler for this model, based on the HParams. + This method is called automatically by the Lightning Trainer in module fitting. - @return: A dictionary containing the optimizer and lr_scheduler. - """ - kwargs = dict() - - match self.hparams.optimizer: - case str() as name: - optimizer = utils.get_optimizer(name)(self.parameters(), **kwargs) - case dict() as kwargs: - # Copy the dict so we don't modify the original - kwargs = deepcopy(kwargs) - - name = kwargs.pop("name") - optimizer = utils.get_optimizer(name)(self.parameters(), **kwargs) - case type(torch.optim.Optimizer) as Optimizer: - optimizer = Optimizer(self.parameters(), **kwargs) - case torch.optim.Optimizer() as optimizer: - pass - case None: - return None - case other: - raise NotImplementedError(f"Unrecognized Optimizer: {other}") + By default, we use one optimizer and zero or one learning rate scheduler. + If you want to use multiple optimizers or learning rate schedulers, you must override this method. - lr_scheduler = self.configure_lr_schedulers(optimizer) + @return: A dictionary containing the optimizer and learning rate scheduler, if any. 
+ """ + optimizer = optimizers.configure(self) + lr_scheduler = lr_schedulers.configure(self, optimizer) if lr_scheduler is None: return optimizer diff --git a/src/lightning_trainable/trainable/trainable_hparams.py b/src/lightning_trainable/trainable/trainable_hparams.py index 32b6375..dc8fc13 100644 --- a/src/lightning_trainable/trainable/trainable_hparams.py +++ b/src/lightning_trainable/trainable/trainable_hparams.py @@ -1,6 +1,10 @@ +import warnings + from lightning_trainable.hparams import HParams from lightning.pytorch.profilers import Profiler +from lightning_trainable.utils import deprecate + class TrainableHParams(HParams): @@ -25,5 +29,35 @@ class TrainableHParams(HParams): @classmethod def _migrate_hparams(cls, hparams): if "accumulate_batches" in hparams and hparams["accumulate_batches"] is None: + deprecate("accumulate_batches changed default value: None -> 1") hparams["accumulate_batches"] = 1 + + if "optimizer" in hparams: + match hparams["optimizer"]: + case str() as name: + if name == name.lower(): + deprecate("optimizer name is now case-sensitive.") + if name == "adam": + hparams["optimizer"] = "Adam" + case dict() as kwargs: + name = kwargs["name"] + if name == name.lower(): + deprecate("optimizer name is now case-sensitive.") + if name == "adam": + hparams["optimizer"]["name"] = "Adam" + + if "lr_scheduler" in hparams: + match hparams["lr_scheduler"]: + case str() as name: + if name == name.lower(): + deprecate("lr_scheduler name is now case-sensitive.") + if name == "onecyclelr": + hparams["lr_scheduler"] = "OneCycleLR" + case dict() as kwargs: + name = kwargs["name"] + if name == name.lower(): + deprecate("lr_scheduler name is now case-sensitive.") + if name == "onecyclelr": + hparams["lr_scheduler"]["name"] = "OneCycleLR" + return hparams diff --git a/tests/test_trainable.py b/tests/test_trainable.py index b5e8ab3..41ade21 100644 --- a/tests/test_trainable.py +++ b/tests/test_trainable.py @@ -30,7 +30,7 @@ def compute_metrics(self, batch, batch_idx) -> dict: accelerator="cpu", max_epochs=10, batch_size=32, - lr_scheduler="onecyclelr" + lr_scheduler="OneCycleLR" ) model = SimpleTrainable(hparams, train_data=train_data) model.fit() @@ -56,7 +56,7 @@ def compute_metrics(self, batch, batch_idx) -> dict: max_epochs=1, batch_size=8, optimizer=dict( - name="adam", + name="Adam", lr=1e-3, ) ) From 4c02212f1725ccb6211d4d66500950624788f272 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Fri, 12 May 2023 20:42:55 +0200 Subject: [PATCH 07/53] add wasserstein distance to metrics --- .../metrics/wasserstein.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 src/lightning_trainable/metrics/wasserstein.py diff --git a/src/lightning_trainable/metrics/wasserstein.py b/src/lightning_trainable/metrics/wasserstein.py new file mode 100644 index 0000000..cf533db --- /dev/null +++ b/src/lightning_trainable/metrics/wasserstein.py @@ -0,0 +1,37 @@ + +import torch + + +def wasserstein( + x: torch.Tensor, + y: torch.Tensor, + cost: torch.Tensor = None, + epsilon: int | float = 0.1, + steps: int = 100 +) -> torch.Tensor: + """ + Compute the Wasserstein distance between two distributions. + @param x: Samples from the first distribution. + @param y: Samples from the second distribution. + @param cost: Optional cost matrix. If not provided, the L2 distance is used. + @param epsilon: The entropic regularization parameter. + @param steps: The number of Sinkhorn iterations. 
+ """ + if cost is None: + cost = x[:, None] - y[None, :] + cost = torch.flatten(cost, start_dim=2) + cost = torch.linalg.norm(cost, dim=-1) + + if cost.shape != (x.shape[0], y.shape[0]): + raise ValueError(f"Expected cost matrix of shape {(x.shape[0], y.shape[0])}, but got {cost.shape}.") + + u = torch.zeros(x.shape[0], device=x.device) + v = torch.zeros(y.shape[0], device=y.device) + + for step in range(steps): + u = epsilon * torch.logsumexp(-cost + v[None, :] / epsilon, dim=1) + v = epsilon * torch.logsumexp(-cost + u[:, None] / epsilon, dim=0) + + w = torch.sum(u * torch.sum(cost * torch.exp(-(u[:, None] + v[None, :]) / epsilon), dim=1), dim=0) + + return w From af148eeb77125f78f249c52e91edcc6775092b86 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Fri, 12 May 2023 20:44:23 +0200 Subject: [PATCH 08/53] shorten wasserstein signature --- src/lightning_trainable/metrics/wasserstein.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/lightning_trainable/metrics/wasserstein.py b/src/lightning_trainable/metrics/wasserstein.py index cf533db..d020230 100644 --- a/src/lightning_trainable/metrics/wasserstein.py +++ b/src/lightning_trainable/metrics/wasserstein.py @@ -1,14 +1,9 @@ import torch +from torch import Tensor -def wasserstein( - x: torch.Tensor, - y: torch.Tensor, - cost: torch.Tensor = None, - epsilon: int | float = 0.1, - steps: int = 100 -) -> torch.Tensor: +def wasserstein(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: int | float = 0.1, steps: int = 100) -> Tensor: """ Compute the Wasserstein distance between two distributions. @param x: Samples from the first distribution. From b55556c870f5b232b56295746995ec474f0dfe62 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Fri, 12 May 2023 20:45:37 +0200 Subject: [PATCH 09/53] more error handling --- src/lightning_trainable/metrics/wasserstein.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/lightning_trainable/metrics/wasserstein.py b/src/lightning_trainable/metrics/wasserstein.py index d020230..55b5fd7 100644 --- a/src/lightning_trainable/metrics/wasserstein.py +++ b/src/lightning_trainable/metrics/wasserstein.py @@ -12,12 +12,14 @@ def wasserstein(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: int | float @param epsilon: The entropic regularization parameter. @param steps: The number of Sinkhorn iterations. 
""" + if x.shape[1:] != y.shape[1:]: + raise ValueError(f"x and y must have the same feature dimensions, but got {x.shape[1:]} and {y.shape[1:]}.") + if cost is None: cost = x[:, None] - y[None, :] cost = torch.flatten(cost, start_dim=2) cost = torch.linalg.norm(cost, dim=-1) - - if cost.shape != (x.shape[0], y.shape[0]): + elif cost.shape != (x.shape[0], y.shape[0]): raise ValueError(f"Expected cost matrix of shape {(x.shape[0], y.shape[0])}, but got {cost.shape}.") u = torch.zeros(x.shape[0], device=x.device) From c8e459fe12c412fa8bf6579dda21985b42e459b0 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Fri, 12 May 2023 20:51:24 +0200 Subject: [PATCH 10/53] add import --- src/lightning_trainable/metrics/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lightning_trainable/metrics/__init__.py b/src/lightning_trainable/metrics/__init__.py index 9f303c3..d94ccf1 100644 --- a/src/lightning_trainable/metrics/__init__.py +++ b/src/lightning_trainable/metrics/__init__.py @@ -1,2 +1,3 @@ from .accuracy import accuracy from .error import error +from .wasserstein import wasserstein From 49f58321494af9d6062104b93f3ca9e735c3a30b Mon Sep 17 00:00:00 2001 From: LarsKue Date: Fri, 12 May 2023 21:14:48 +0200 Subject: [PATCH 11/53] update with optimal transport plan pi --- src/lightning_trainable/metrics/wasserstein.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/lightning_trainable/metrics/wasserstein.py b/src/lightning_trainable/metrics/wasserstein.py index 55b5fd7..df376fa 100644 --- a/src/lightning_trainable/metrics/wasserstein.py +++ b/src/lightning_trainable/metrics/wasserstein.py @@ -3,9 +3,9 @@ from torch import Tensor -def wasserstein(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: int | float = 0.1, steps: int = 100) -> Tensor: +def plan(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: int | float = 0.1, steps: int = 100) -> Tensor: """ - Compute the Wasserstein distance between two distributions. + Compute the optimal transport plan pi between two distributions. @param x: Samples from the first distribution. @param y: Samples from the second distribution. @param cost: Optional cost matrix. If not provided, the L2 distance is used. @@ -29,6 +29,17 @@ def wasserstein(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: int | float u = epsilon * torch.logsumexp(-cost + v[None, :] / epsilon, dim=1) v = epsilon * torch.logsumexp(-cost + u[:, None] / epsilon, dim=0) - w = torch.sum(u * torch.sum(cost * torch.exp(-(u[:, None] + v[None, :]) / epsilon), dim=1), dim=0) + pi = torch.exp(-(cost + u[:, None] + v[None, :]) / epsilon) + + return pi + + +def wasserstein(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: int | float = 0.1, steps: int = 100) -> Tensor: + """ + Compute the Wasserstein distance between two distributions. See plan for parameter descriptions. 
+ """ + pi = plan(x, y, cost, epsilon, steps) + + w = torch.sum(pi * cost) return w From 4ff8b2c8fd2bf2d962f84d78beb344e18c94d8a4 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Sat, 13 May 2023 00:29:46 +0200 Subject: [PATCH 12/53] fix wasserstein and sinkhorn --- .../metrics/wasserstein.py | 68 ++++++++++++------- 1 file changed, 44 insertions(+), 24 deletions(-) diff --git a/src/lightning_trainable/metrics/wasserstein.py b/src/lightning_trainable/metrics/wasserstein.py index df376fa..23b6402 100644 --- a/src/lightning_trainable/metrics/wasserstein.py +++ b/src/lightning_trainable/metrics/wasserstein.py @@ -3,43 +3,63 @@ from torch import Tensor -def plan(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: int | float = 0.1, steps: int = 100) -> Tensor: - """ - Compute the optimal transport plan pi between two distributions. - @param x: Samples from the first distribution. - @param y: Samples from the second distribution. - @param cost: Optional cost matrix. If not provided, the L2 distance is used. - @param epsilon: The entropic regularization parameter. - @param steps: The number of Sinkhorn iterations. - """ +def _process_cost(x: Tensor, y: Tensor, cost: Tensor = None) -> Tensor: if x.shape[1:] != y.shape[1:]: - raise ValueError(f"x and y must have the same feature dimensions, but got {x.shape[1:]} and {y.shape[1:]}.") - + raise ValueError(f"Expected x and y to live in the same feature space, " + f"but got {x.shape[1:]} and {y.shape[1:]}.") if cost is None: cost = x[:, None] - y[None, :] cost = torch.flatten(cost, start_dim=2) cost = torch.linalg.norm(cost, dim=-1) elif cost.shape != (x.shape[0], y.shape[0]): - raise ValueError(f"Expected cost matrix of shape {(x.shape[0], y.shape[0])}, but got {cost.shape}.") + raise ValueError(f"Expected cost to have shape {(x.shape[0], y.shape[0])}, " + f"but got {cost.shape}.") + + return cost - u = torch.zeros(x.shape[0], device=x.device) - v = torch.zeros(y.shape[0], device=y.device) - for step in range(steps): - u = epsilon * torch.logsumexp(-cost + v[None, :] / epsilon, dim=1) - v = epsilon * torch.logsumexp(-cost + u[:, None] / epsilon, dim=0) +def sinkhorn(a: Tensor, b: Tensor, cost: Tensor, epsilon: float = 0.1, steps: int = 100) -> Tensor: + """ + Computes the Sinkhorn optimal transport plan between two distributions. + @param a: Sample weights from the first distribution in shape (n,) + @param b: Sample weights from the second distribution in shape (m,) + @param cost: Cost matrix in shape (n, m). + @param epsilon: Entropic regularization parameter. + @param steps: Number of iterations. + """ + if cost.shape != (len(a), len(b)): + raise ValueError(f"Expected cost to have shape {(len(a), len(b))}, but got {cost.shape}.") + + gain = torch.exp(-cost / epsilon) - pi = torch.exp(-(cost + u[:, None] + v[None, :]) / epsilon) + # Initialize the dual variables. + u = torch.ones(len(a)) + v = torch.ones(len(b)) - return pi + # Compute the Sinkhorn iterations. + for _ in range(steps): + v = b / (torch.matmul(gain.T, u) + 1e-50) + u = a / (torch.matmul(gain, v) + 1e-50) + # Return the transport plan. + return u[:, None] * gain * v[None, :] -def wasserstein(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: int | float = 0.1, steps: int = 100) -> Tensor: + +def wasserstein(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: float = 0.1, steps: int = 100) -> Tensor: """ - Compute the Wasserstein distance between two distributions. See plan for parameter descriptions. + Computes the Wasserstein distance between two distributions. 
See sinkhorn for more details. """ - pi = plan(x, y, cost, epsilon, steps) + if x.shape[1:] != y.shape[1:]: + raise ValueError(f"Expected x and y to live in the same feature space, " + f"but got {x.shape[1:]} and {y.shape[1:]}.") + if cost is None: + cost = x[:, None] - y[None, :] + cost = torch.flatten(cost, start_dim=2) + cost = torch.linalg.norm(cost, dim=-1) - w = torch.sum(pi * cost) + # Initialize the sample weights. + a = torch.ones(len(x)) / len(x) + b = torch.ones(len(y)) / len(y) - return w + # Compute the transport plan and return the Wasserstein distance. + return torch.sum(sinkhorn(a, b, cost, epsilon, steps) * cost) From b1fe32229e389c51e4640aab15828b308722685b Mon Sep 17 00:00:00 2001 From: LarsKue Date: Sat, 13 May 2023 00:34:49 +0200 Subject: [PATCH 13/53] add sinkhorn_auto, clean up --- .../metrics/wasserstein.py | 38 +++++++++---------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/src/lightning_trainable/metrics/wasserstein.py b/src/lightning_trainable/metrics/wasserstein.py index 23b6402..8213970 100644 --- a/src/lightning_trainable/metrics/wasserstein.py +++ b/src/lightning_trainable/metrics/wasserstein.py @@ -3,24 +3,9 @@ from torch import Tensor -def _process_cost(x: Tensor, y: Tensor, cost: Tensor = None) -> Tensor: - if x.shape[1:] != y.shape[1:]: - raise ValueError(f"Expected x and y to live in the same feature space, " - f"but got {x.shape[1:]} and {y.shape[1:]}.") - if cost is None: - cost = x[:, None] - y[None, :] - cost = torch.flatten(cost, start_dim=2) - cost = torch.linalg.norm(cost, dim=-1) - elif cost.shape != (x.shape[0], y.shape[0]): - raise ValueError(f"Expected cost to have shape {(x.shape[0], y.shape[0])}, " - f"but got {cost.shape}.") - - return cost - - def sinkhorn(a: Tensor, b: Tensor, cost: Tensor, epsilon: float = 0.1, steps: int = 100) -> Tensor: """ - Computes the Sinkhorn optimal transport plan between two distributions. + Computes the Sinkhorn optimal transport plan from sample weights of two distributions. @param a: Sample weights from the first distribution in shape (n,) @param b: Sample weights from the second distribution in shape (m,) @param cost: Cost matrix in shape (n, m). @@ -45,9 +30,15 @@ def sinkhorn(a: Tensor, b: Tensor, cost: Tensor, epsilon: float = 0.1, steps: in return u[:, None] * gain * v[None, :] -def wasserstein(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: float = 0.1, steps: int = 100) -> Tensor: +def sinkhorn_auto(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: float = 0.1, steps: int = 100) -> Tensor: """ - Computes the Wasserstein distance between two distributions. See sinkhorn for more details. + Computes the Sinkhorn optimal transport plan from samples from two distributions. + See also: sinkhorn + @param x: Samples from the first distribution in shape (n, ...). + @param y: Samples from the second distribution in shape (m, ...). + @param cost: Optional cost matrix in shape (n, m). If not provided, the Euclidean distance is used. + @param epsilon: Entropic regularization parameter. + @param steps: Number of iterations. """ if x.shape[1:] != y.shape[1:]: raise ValueError(f"Expected x and y to live in the same feature space, " @@ -61,5 +52,12 @@ def wasserstein(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: float = 0.1, a = torch.ones(len(x)) / len(x) b = torch.ones(len(y)) / len(y) - # Compute the transport plan and return the Wasserstein distance. 
- return torch.sum(sinkhorn(a, b, cost, epsilon, steps) * cost) + return sinkhorn(a, b, cost, epsilon, steps) + + +def wasserstein(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: float = 0.1, steps: int = 100) -> Tensor: + """ + Computes the Wasserstein distance between two distributions. + See also: sinkhorn_auto + """ + return torch.sum(sinkhorn_auto(x, y, cost, epsilon, steps) * cost) From e40e4d49f0308cb5e268ab8fd5f3767373106d4c Mon Sep 17 00:00:00 2001 From: LarsKue Date: Sat, 13 May 2023 00:53:30 +0200 Subject: [PATCH 14/53] add automatic bandwidth determination --- src/lightning_trainable/metrics/wasserstein.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/lightning_trainable/metrics/wasserstein.py b/src/lightning_trainable/metrics/wasserstein.py index 8213970..ce67cea 100644 --- a/src/lightning_trainable/metrics/wasserstein.py +++ b/src/lightning_trainable/metrics/wasserstein.py @@ -1,9 +1,11 @@ +import warnings + import torch from torch import Tensor -def sinkhorn(a: Tensor, b: Tensor, cost: Tensor, epsilon: float = 0.1, steps: int = 100) -> Tensor: +def sinkhorn(a: Tensor, b: Tensor, cost: Tensor, epsilon: float, steps: int = 100) -> Tensor: """ Computes the Sinkhorn optimal transport plan from sample weights of two distributions. @param a: Sample weights from the first distribution in shape (n,) @@ -17,6 +19,10 @@ def sinkhorn(a: Tensor, b: Tensor, cost: Tensor, epsilon: float = 0.1, steps: in gain = torch.exp(-cost / epsilon) + if gain.mean() < 1e-30: + warnings.warn(f"Detected low bandwidth ({epsilon:.1e}) relative to cost ({cost.mean().item():.1e}). " + f"You may experience numerical instabilities. Consider increasing epsilon.") + # Initialize the dual variables. u = torch.ones(len(a)) v = torch.ones(len(b)) @@ -30,7 +36,7 @@ def sinkhorn(a: Tensor, b: Tensor, cost: Tensor, epsilon: float = 0.1, steps: in return u[:, None] * gain * v[None, :] -def sinkhorn_auto(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: float = 0.1, steps: int = 100) -> Tensor: +def sinkhorn_auto(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: float = None, steps: int = 100) -> Tensor: """ Computes the Sinkhorn optimal transport plan from samples from two distributions. See also: sinkhorn @@ -48,6 +54,9 @@ def sinkhorn_auto(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: float = 0. cost = torch.flatten(cost, start_dim=2) cost = torch.linalg.norm(cost, dim=-1) + if epsilon is None: + epsilon = torch.quantile(cost, 0.1).item() + # Initialize the sample weights. a = torch.ones(len(x)) / len(x) b = torch.ones(len(y)) / len(y) From c1a1a14e507b7dd71dcf7f13827fc5ce6a37ea07 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Sat, 13 May 2023 01:01:07 +0200 Subject: [PATCH 15/53] fix default bandwidth --- src/lightning_trainable/metrics/wasserstein.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning_trainable/metrics/wasserstein.py b/src/lightning_trainable/metrics/wasserstein.py index ce67cea..b0579c0 100644 --- a/src/lightning_trainable/metrics/wasserstein.py +++ b/src/lightning_trainable/metrics/wasserstein.py @@ -55,7 +55,7 @@ def sinkhorn_auto(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: float = No cost = torch.linalg.norm(cost, dim=-1) if epsilon is None: - epsilon = torch.quantile(cost, 0.1).item() + epsilon = 0.5 * cost.mean() ** 2 / x[0].numel() # Initialize the sample weights. 
a = torch.ones(len(x)) / len(x) From afb1aa63b2a7f3c6b0d93da00a122c70f03cb331 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Sat, 13 May 2023 01:02:32 +0200 Subject: [PATCH 16/53] add comment why default bandwidth is better --- src/lightning_trainable/metrics/wasserstein.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lightning_trainable/metrics/wasserstein.py b/src/lightning_trainable/metrics/wasserstein.py index b0579c0..c48c232 100644 --- a/src/lightning_trainable/metrics/wasserstein.py +++ b/src/lightning_trainable/metrics/wasserstein.py @@ -55,6 +55,7 @@ def sinkhorn_auto(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: float = No cost = torch.linalg.norm(cost, dim=-1) if epsilon is None: + # this converges to 1 for normally distributed data epsilon = 0.5 * cost.mean() ** 2 / x[0].numel() # Initialize the sample weights. From c3fd80aacb41f2e450dc7dce3ecaf9ecc1cf8768 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Sat, 13 May 2023 01:07:48 +0200 Subject: [PATCH 17/53] Turns out, the old epsilon was better (of course) --- src/lightning_trainable/metrics/wasserstein.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/lightning_trainable/metrics/wasserstein.py b/src/lightning_trainable/metrics/wasserstein.py index c48c232..1fce37e 100644 --- a/src/lightning_trainable/metrics/wasserstein.py +++ b/src/lightning_trainable/metrics/wasserstein.py @@ -42,8 +42,10 @@ def sinkhorn_auto(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: float = No See also: sinkhorn @param x: Samples from the first distribution in shape (n, ...). @param y: Samples from the second distribution in shape (m, ...). - @param cost: Optional cost matrix in shape (n, m). If not provided, the Euclidean distance is used. - @param epsilon: Entropic regularization parameter. + @param cost: Optional cost matrix in shape (n, m). + If not provided, the Euclidean distance is used. + @param epsilon: Optional entropic regularization parameter. + If not provided, the half-mean of the cost matrix is used. @param steps: Number of iterations. """ if x.shape[1:] != y.shape[1:]: @@ -55,8 +57,7 @@ def sinkhorn_auto(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: float = No cost = torch.linalg.norm(cost, dim=-1) if epsilon is None: - # this converges to 1 for normally distributed data - epsilon = 0.5 * cost.mean() ** 2 / x[0].numel() + epsilon = cost.mean() / 2 # Initialize the sample weights. a = torch.ones(len(x)) / len(x) From 6dbd49971328726533448e18c3d2c2e1316cbc4e Mon Sep 17 00:00:00 2001 From: LarsKue Date: Sat, 13 May 2023 01:14:38 +0200 Subject: [PATCH 18/53] lower default number of iterations --- src/lightning_trainable/metrics/wasserstein.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lightning_trainable/metrics/wasserstein.py b/src/lightning_trainable/metrics/wasserstein.py index 1fce37e..821ef27 100644 --- a/src/lightning_trainable/metrics/wasserstein.py +++ b/src/lightning_trainable/metrics/wasserstein.py @@ -5,7 +5,7 @@ from torch import Tensor -def sinkhorn(a: Tensor, b: Tensor, cost: Tensor, epsilon: float, steps: int = 100) -> Tensor: +def sinkhorn(a: Tensor, b: Tensor, cost: Tensor, epsilon: float, steps: int = 10) -> Tensor: """ Computes the Sinkhorn optimal transport plan from sample weights of two distributions. 
@param a: Sample weights from the first distribution in shape (n,) @@ -36,7 +36,7 @@ def sinkhorn(a: Tensor, b: Tensor, cost: Tensor, epsilon: float, steps: int = 10 return u[:, None] * gain * v[None, :] -def sinkhorn_auto(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: float = None, steps: int = 100) -> Tensor: +def sinkhorn_auto(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: float = None, steps: int = 10) -> Tensor: """ Computes the Sinkhorn optimal transport plan from samples from two distributions. See also: sinkhorn @@ -66,7 +66,7 @@ def sinkhorn_auto(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: float = No return sinkhorn(a, b, cost, epsilon, steps) -def wasserstein(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: float = 0.1, steps: int = 100) -> Tensor: +def wasserstein(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: float = 0.1, steps: int = 10) -> Tensor: """ Computes the Wasserstein distance between two distributions. See also: sinkhorn_auto From d48738b6e1a17af3bd4e051b2bf4ab2c3854e823 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Sat, 13 May 2023 02:23:12 +0200 Subject: [PATCH 19/53] dtype inference for sinkhorn_auto --- src/lightning_trainable/metrics/wasserstein.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/lightning_trainable/metrics/wasserstein.py b/src/lightning_trainable/metrics/wasserstein.py index 821ef27..5e4aa39 100644 --- a/src/lightning_trainable/metrics/wasserstein.py +++ b/src/lightning_trainable/metrics/wasserstein.py @@ -24,8 +24,8 @@ def sinkhorn(a: Tensor, b: Tensor, cost: Tensor, epsilon: float, steps: int = 10 f"You may experience numerical instabilities. Consider increasing epsilon.") # Initialize the dual variables. - u = torch.ones(len(a)) - v = torch.ones(len(b)) + u = torch.ones(len(a), dtype=a.dtype, device=a.device) + v = torch.ones(len(b), dtype=b.dtype, device=b.device) # Compute the Sinkhorn iterations. for _ in range(steps): @@ -60,8 +60,8 @@ def sinkhorn_auto(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: float = No epsilon = cost.mean() / 2 # Initialize the sample weights. - a = torch.ones(len(x)) / len(x) - b = torch.ones(len(y)) / len(y) + a = torch.ones(len(x), dtype=x.dtype, device=x.device) / len(x) + b = torch.ones(len(y), dtype=y.dtype, device=y.device) / len(y) return sinkhorn(a, b, cost, epsilon, steps) @@ -71,4 +71,5 @@ def wasserstein(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: float = 0.1, Computes the Wasserstein distance between two distributions. See also: sinkhorn_auto """ + # TODO: fix for cost = None return torch.sum(sinkhorn_auto(x, y, cost, epsilon, steps) * cost) From 120e87398e70aa7f3621b2e9c236ee81c764fe97 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Tue, 16 May 2023 13:14:55 +0200 Subject: [PATCH 20/53] add log-space sinkhorn --- .../metrics/wasserstein.py | 38 ++++++++++++++++++- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/src/lightning_trainable/metrics/wasserstein.py b/src/lightning_trainable/metrics/wasserstein.py index 5e4aa39..0b767ce 100644 --- a/src/lightning_trainable/metrics/wasserstein.py +++ b/src/lightning_trainable/metrics/wasserstein.py @@ -8,6 +8,8 @@ def sinkhorn(a: Tensor, b: Tensor, cost: Tensor, epsilon: float, steps: int = 10) -> Tensor: """ Computes the Sinkhorn optimal transport plan from sample weights of two distributions. + This version does not use log-space computations, but allows for zero or negative weights. 
+ @param a: Sample weights from the first distribution in shape (n,) @param b: Sample weights from the second distribution in shape (m,) @param cost: Cost matrix in shape (n, m). @@ -36,10 +38,42 @@ def sinkhorn(a: Tensor, b: Tensor, cost: Tensor, epsilon: float, steps: int = 10 return u[:, None] * gain * v[None, :] +def sinkhorn_log(a: Tensor, b: Tensor, cost: Tensor, epsilon: float, steps: int = 10) -> Tensor: + """ + Computes the Sinkhorn optimal transport plan from sample weights of two distributions. + This version uses log-space computations to avoid numerical instabilities, but disallows zero or negative weights. + + @param a: Sample weights from the first distribution in shape (n,) + @param b: Sample weights from the second distribution in shape (m,) + @param cost: Cost matrix in shape (n, m). + @param epsilon: Entropic regularization parameter. + @param steps: Number of iterations. + """ + if torch.any(a <= 0) or torch.any(b <= 0): + raise ValueError("Expected sample weights to be non-negative.") + if cost.shape != (len(a), len(b)): + raise ValueError(f"Expected cost to have shape {(len(a), len(b))}, but got {cost.shape}.") + + log_gain = -cost / epsilon + + # Initialize the dual variables. + log_u = torch.zeros(len(a), dtype=a.dtype, device=a.device) + log_v = torch.zeros(len(b), dtype=b.dtype, device=b.device) + + # Compute the Sinkhorn iterations. + for _ in range(steps): + log_v = torch.log(b) - torch.logsumexp(log_gain + log_u[:, None], dim=0) + log_u = torch.log(a) - torch.logsumexp(log_gain + log_v[None, :], dim=1) + + # Return the transport plan. + return torch.exp(log_u[:, None] + log_gain + log_v[None, :]) + + def sinkhorn_auto(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: float = None, steps: int = 10) -> Tensor: """ Computes the Sinkhorn optimal transport plan from samples from two distributions. - See also: sinkhorn + See also: sinkhorn_log + @param x: Samples from the first distribution in shape (n, ...). @param y: Samples from the second distribution in shape (m, ...). @param cost: Optional cost matrix in shape (n, m). @@ -63,7 +97,7 @@ def sinkhorn_auto(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: float = No a = torch.ones(len(x), dtype=x.dtype, device=x.device) / len(x) b = torch.ones(len(y), dtype=y.dtype, device=y.device) / len(y) - return sinkhorn(a, b, cost, epsilon, steps) + return sinkhorn_log(a, b, cost, epsilon, steps) def wasserstein(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: float = 0.1, steps: int = 10) -> Tensor: From e048d2cc1bef7b2e16174ec34e9b3918e98f4857 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Tue, 16 May 2023 13:22:46 +0200 Subject: [PATCH 21/53] improve warning for low bandwidth --- src/lightning_trainable/metrics/wasserstein.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/lightning_trainable/metrics/wasserstein.py b/src/lightning_trainable/metrics/wasserstein.py index 0b767ce..e5dcb44 100644 --- a/src/lightning_trainable/metrics/wasserstein.py +++ b/src/lightning_trainable/metrics/wasserstein.py @@ -23,7 +23,7 @@ def sinkhorn(a: Tensor, b: Tensor, cost: Tensor, epsilon: float, steps: int = 10 if gain.mean() < 1e-30: warnings.warn(f"Detected low bandwidth ({epsilon:.1e}) relative to cost ({cost.mean().item():.1e}). " - f"You may experience numerical instabilities. Consider increasing epsilon.") + f"You may experience numerical instabilities. Consider increasing epsilon or using sinkhorn_log.") # Initialize the dual variables. 
u = torch.ones(len(a), dtype=a.dtype, device=a.device) @@ -105,5 +105,9 @@ def wasserstein(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: float = 0.1, Computes the Wasserstein distance between two distributions. See also: sinkhorn_auto """ - # TODO: fix for cost = None + if cost is None: + cost = x[:, None] - y[None, :] + cost = torch.flatten(cost, start_dim=2) + cost = torch.linalg.norm(cost, dim=-1) + return torch.sum(sinkhorn_auto(x, y, cost, epsilon, steps) * cost) From fa2c3df53c8285a12d648f04b010fbae149ccf4b Mon Sep 17 00:00:00 2001 From: Lars Kuehmichel Date: Wed, 14 Jun 2023 14:33:55 +0200 Subject: [PATCH 22/53] Add nicer Choice and Range print --- src/lightning_trainable/hparams/types/choice.py | 3 +++ src/lightning_trainable/hparams/types/range.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/src/lightning_trainable/hparams/types/choice.py b/src/lightning_trainable/hparams/types/choice.py index 4c0c545..00315ca 100644 --- a/src/lightning_trainable/hparams/types/choice.py +++ b/src/lightning_trainable/hparams/types/choice.py @@ -11,6 +11,9 @@ def __call__(cls, *choices): namespace = {"choices": choices} return type(name, bases, namespace) + def __repr__(cls): + return f"Choice{cls.choices!r}" + class Choice(metaclass=ChoiceMeta): """ diff --git a/src/lightning_trainable/hparams/types/range.py b/src/lightning_trainable/hparams/types/range.py index b982b93..a20b739 100644 --- a/src/lightning_trainable/hparams/types/range.py +++ b/src/lightning_trainable/hparams/types/range.py @@ -21,6 +21,9 @@ def __call__(cls, lower: float | int, upper: float | int, exclude: str | None = namespace = {"lower": lower, "upper": upper, "exclude": exclude} return type(name, bases, namespace) + def __repr__(self): + return f"Range({self.lower!r}, {self.upper!r}, exclude={self.exclude!r})" + class Range(metaclass=RangeMeta): """ From a673880d81c657c7f3f1f403d3626ddd7e69375c Mon Sep 17 00:00:00 2001 From: Lars Kuehmichel Date: Wed, 14 Jun 2023 14:34:15 +0200 Subject: [PATCH 23/53] Add norm and dropout to FullyConnectedNetwork --- .../modules/fully_connected/hparams.py | 3 + .../modules/fully_connected/network.py | 61 ++++++++++++++++--- 2 files changed, 54 insertions(+), 10 deletions(-) diff --git a/src/lightning_trainable/modules/fully_connected/hparams.py b/src/lightning_trainable/modules/fully_connected/hparams.py index 2d3dccb..7a981d1 100644 --- a/src/lightning_trainable/modules/fully_connected/hparams.py +++ b/src/lightning_trainable/modules/fully_connected/hparams.py @@ -8,3 +8,6 @@ class FullyConnectedNetworkHParams(HParams): layer_widths: list[int] activation: str = "relu" + + norm: Choice("none", "batch", "layer") = "none" + dropout: float = 0.0 diff --git a/src/lightning_trainable/modules/fully_connected/network.py b/src/lightning_trainable/modules/fully_connected/network.py index d1b0583..bb9ecf3 100644 --- a/src/lightning_trainable/modules/fully_connected/network.py +++ b/src/lightning_trainable/modules/fully_connected/network.py @@ -23,18 +23,59 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.network(x) def configure_network(self): - layers = [] - widths = [self.hparams.input_dims, *self.hparams.layer_widths, self.hparams.output_dims] + widths = self.hparams.layer_widths - if self.hparams.input_dims == "lazy": - widths = widths[1:] - layers = [nn.LazyLinear(widths[0]), get_activation(self.hparams.activation)(inplace=True)] + # input layer + input_linear = self.configure_linear(self.hparams.input_dims, widths[0]) + input_activation = self.configure_activation() 
+ layers = [input_linear, input_activation] - for (w1, w2) in zip(widths[:-1], widths[1:]): - layers.append(nn.Linear(w1, w2)) - layers.append(get_activation(self.hparams.activation)(inplace=True)) + # hidden layers + for (in_features, out_features) in zip(widths[:-1], widths[1:]): + dropout = self.configure_dropout() + norm = self.configure_norm(in_features) - # remove last activation - layers = layers[:-1] + activation = self.configure_activation() + linear = self.configure_linear(in_features, out_features) + + if dropout is not None: + layers.append(dropout) + + if norm is not None: + layers.append(norm) + + layers.append(linear) + layers.append(activation) + + # output layer + output_linear = self.configure_linear(widths[-1], self.hparams.output_dims) + layers.append(output_linear) return nn.Sequential(*layers) + + def configure_linear(self, in_features, out_features) -> nn.Module: + match in_features: + case "lazy": + return nn.LazyLinear(out_features) + case int() as in_features: + return nn.Linear(in_features, out_features) + case other: + raise NotImplementedError(f"Unrecognized input_dims value: '{other}'") + + def configure_activation(self) -> nn.Module: + return get_activation(self.hparams.activation)(inplace=True) + + def configure_dropout(self) -> nn.Module | None: + if self.hparams.dropout > 0: + return nn.Dropout(self.hparams.dropout) + + def configure_norm(self, num_features) -> nn.Module | None: + match self.hparams.norm: + case "none": + return None + case "batch": + return nn.BatchNorm1d(num_features) + case "layer": + return nn.LayerNorm(num_features) + case other: + raise NotImplementedError(f"Unrecognized norm value: '{other}'") From 9f99bd5c1532243a13b5fcadb097fc35901e98b3 Mon Sep 17 00:00:00 2001 From: Lars Kuehmichel Date: Wed, 14 Jun 2023 14:34:26 +0200 Subject: [PATCH 24/53] Fix deprecate decorator --- src/lightning_trainable/utils/deprecate.py | 31 ++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/src/lightning_trainable/utils/deprecate.py b/src/lightning_trainable/utils/deprecate.py index bc58fa9..ff592d4 100644 --- a/src/lightning_trainable/utils/deprecate.py +++ b/src/lightning_trainable/utils/deprecate.py @@ -1,6 +1,33 @@ import warnings +from functools import wraps +import inspect -def deprecate(message: str): - warnings.warn(message, DeprecationWarning, stacklevel=2) +def deprecate(message: str, version: str = None): + + def wrapper(obj): + nonlocal message + + if inspect.isclass(obj): + name = "Class" + elif inspect.isfunction(obj): + name = "Function" + elif inspect.ismethod(obj): + name = "Method" + else: + name = "Object" + + if version is not None: + message = f"{name} '{obj.__name__}' is deprecated since version {version}: {message}" + else: + message = f"{name} '{obj.__name__}' is deprecated: {message}" + + @wraps(obj) + def wrapped(*args, **kwargs): + warnings.warn(message, DeprecationWarning, stacklevel=2) + return obj(*args, **kwargs) + + return wrapped + + return wrapper From df2f18d73b70a332be824d443928da9c40b30e74 Mon Sep 17 00:00:00 2001 From: Lars Kuehmichel Date: Wed, 14 Jun 2023 14:35:10 +0200 Subject: [PATCH 25/53] Deprecate and Disable Progress Bar due to bugs with multi-gpu and train restarts (#8) --- .../callbacks/epoch_progress_bar.py | 6 ++++++ src/lightning_trainable/trainable/trainable.py | 13 +++++-------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/lightning_trainable/callbacks/epoch_progress_bar.py b/src/lightning_trainable/callbacks/epoch_progress_bar.py index 
9dd973f..a375c97 100644 --- a/src/lightning_trainable/callbacks/epoch_progress_bar.py +++ b/src/lightning_trainable/callbacks/epoch_progress_bar.py @@ -2,7 +2,11 @@ from lightning.pytorch.callbacks import ProgressBar from lightning.pytorch.callbacks.progress.tqdm_progress import Tqdm +from lightning_trainable.utils import deprecate + +@deprecate("EpochProgressBar causes issues when continuing training or using multi-GPU. " + "Use the default Lightning ProgressBar instead.") class EpochProgressBar(ProgressBar): def __init__(self): super().__init__() @@ -25,6 +29,8 @@ def on_train_end(self, trainer, pl_module): self.bar.close() +@deprecate("StepProgressBar causes issues when continuing training or using multi-GPU. " + "Use the default Lightning ProgressBar instead.") class StepProgressBar(ProgressBar): def __init__(self): super().__init__() diff --git a/src/lightning_trainable/trainable/trainable.py b/src/lightning_trainable/trainable/trainable.py index d23e5ab..e12ad67 100644 --- a/src/lightning_trainable/trainable/trainable.py +++ b/src/lightning_trainable/trainable/trainable.py @@ -4,13 +4,12 @@ import lightning import torch -from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint, ProgressBar, EarlyStopping +from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint, EarlyStopping from lightning.pytorch.loggers import Logger, TensorBoardLogger from torch.utils.data import DataLoader, Dataset, IterableDataset from tqdm import tqdm from lightning_trainable import utils -from lightning_trainable.callbacks import EpochProgressBar from .trainable_hparams import TrainableHParams from . import lr_schedulers @@ -42,7 +41,10 @@ def __init__( self.test_data = test_data def __init_subclass__(cls, **kwargs): - cls.hparams_type = cls.__annotations__.get("hparams", TrainableHParams) + hparams_type = cls.__annotations__.get("hparams") + if hparams_type is not None: + # only overwrite hparams_type if it is defined by the child class + cls.hparams_type = hparams_type def compute_metrics(self, batch, batch_idx) -> dict: """ @@ -131,7 +133,6 @@ def configure_callbacks(self) -> list: save_top_k=5 ), LearningRateMonitor(), - EpochProgressBar(), ] if self.hparams.early_stopping is not None: callbacks.append(EarlyStopping(monitor, patience=self.hparams.early_stopping)) @@ -216,10 +217,6 @@ def configure_trainer(self, logger_kwargs: dict = None, trainer_kwargs: dict = N if trainer_kwargs is None: trainer_kwargs = dict() - if "enable_progress_bar" not in trainer_kwargs: - if any(isinstance(callback, ProgressBar) for callback in self.configure_callbacks()): - trainer_kwargs["enable_progress_bar"] = False - return lightning.Trainer( accelerator=self.hparams.accelerator.lower(), logger=self.configure_logger(**logger_kwargs), From 78cf42f46489eeeca4fbe8861467c08ce902445b Mon Sep 17 00:00:00 2001 From: Lars Kuehmichel Date: Wed, 14 Jun 2023 14:37:44 +0200 Subject: [PATCH 26/53] Remove optimize_hparams --- .../trainable/trainable.py | 55 ------------------- 1 file changed, 55 deletions(-) diff --git a/src/lightning_trainable/trainable/trainable.py b/src/lightning_trainable/trainable/trainable.py index e12ad67..9654679 100644 --- a/src/lightning_trainable/trainable/trainable.py +++ b/src/lightning_trainable/trainable/trainable.py @@ -311,61 +311,6 @@ def load_checkpoint(cls, root: str | pathlib.Path = "lightning_logs", version: i checkpoint = utils.find_checkpoint(root, version, epoch, step) return cls.load_from_checkpoint(checkpoint, **kwargs) - @classmethod - def 
optimize_hparams(cls, hparams: dict, scheduler: str = "asha", model_kwargs: dict = None, tune_kwargs: dict = None): - """ Optimize the HParams with ray[tune] """ - try: - from ray import tune - import os - except ModuleNotFoundError: - raise ModuleNotFoundError(f"Please install lightning-trainable[experiments] to use `optimize_hparams`.") - - if model_kwargs is None: - model_kwargs = dict() - if tune_kwargs is None: - tune_kwargs = dict() - - path = os.getcwd() - - def train(hparams): - os.chdir(path) - model = cls(hparams, **model_kwargs) - model.fit( - logger_kwargs=dict(save_dir=str(tune.get_trial_dir())), - trainer_kwargs=dict(enable_progress_bar=False) - ) - - return model - - match scheduler: - case "asha": - scheduler = tune.schedulers.AsyncHyperBandScheduler( - time_attr="time_total_s", - max_t=600, - grace_period=180, - reduction_factor=4, - ) - case other: - raise NotImplementedError(f"Unrecognized scheduler: '{other}'") - - reporter = tune.JupyterNotebookReporter( - overwrite=True, - parameter_columns=list(hparams.keys()), - metric_columns=[f"training/{cls.hparams_type.loss}", f"validation/{cls.hparams_type.loss}"], - ) - - analysis = tune.run( - train, - metric=f"validation/{cls.hparams_type.loss}", - mode="min", - config=hparams, - scheduler=scheduler, - progress_reporter=reporter, - **tune_kwargs - ) - - return analysis - def auto_pin_memory(pin_memory: bool | None, accelerator: str): if pin_memory is None: From 8f0f51927d9ec0c6655d0eff19f902695a5bb4fb Mon Sep 17 00:00:00 2001 From: Lars Kuehmichel Date: Thu, 22 Jun 2023 17:52:50 +0200 Subject: [PATCH 27/53] Add checkpoint loading to test_simple_model --- tests/test_trainable.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_trainable.py b/tests/test_trainable.py index 41ade21..b6e7952 100644 --- a/tests/test_trainable.py +++ b/tests/test_trainable.py @@ -35,6 +35,8 @@ def compute_metrics(self, batch, batch_idx) -> dict: model = SimpleTrainable(hparams, train_data=train_data) model.fit() + model.load_checkpoint() + def test_double_train(): class SimpleTrainable(Trainable): From 3f9afcb78f9efb366503dffe0370e1efe072d56f Mon Sep 17 00:00:00 2001 From: Lars Kuehmichel Date: Thu, 22 Jun 2023 17:52:58 +0200 Subject: [PATCH 28/53] add DistributionDataset to imports --- src/lightning_trainable/datasets/core/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lightning_trainable/datasets/core/__init__.py b/src/lightning_trainable/datasets/core/__init__.py index a0f49bf..6e526bd 100644 --- a/src/lightning_trainable/datasets/core/__init__.py +++ b/src/lightning_trainable/datasets/core/__init__.py @@ -1 +1,2 @@ +from .distribution_dataset import DistributionDataset from .joint import JointDataset, JointIterableDataset From fe74af52b49526c6fc9ee110efebcc13f46da256 Mon Sep 17 00:00:00 2001 From: Lars Kuehmichel Date: Thu, 6 Jul 2023 12:03:12 +0200 Subject: [PATCH 29/53] Fix overwrite hparams_type in HParamsModule --- src/lightning_trainable/modules/hparams_module.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/lightning_trainable/modules/hparams_module.py b/src/lightning_trainable/modules/hparams_module.py index 225ce17..d8a0691 100644 --- a/src/lightning_trainable/modules/hparams_module.py +++ b/src/lightning_trainable/modules/hparams_module.py @@ -17,4 +17,7 @@ def __init__(self, hparams: HParams | dict): self.hparams = hparams def __init_subclass__(cls, **kwargs): - cls.hparams_type = cls.__annotations__["hparams"] + hparams_type = cls.__annotations__.get("hparams") + if 
hparams_type is not None: + # only overwrite hparams_type if it is defined by the child class + cls.hparams_type = hparams_type From d4a2f0f34f17167cb25c3acb24e337cf5a7ba213 Mon Sep 17 00:00:00 2001 From: Lars Kuehmichel Date: Thu, 7 Sep 2023 20:34:27 +0200 Subject: [PATCH 30/53] Simple UNet, stabilized metrics --- src/lightning_trainable/metrics/__init__.py | 1 + src/lightning_trainable/metrics/sinkhorn.py | 107 ++++++++++++++++++ .../metrics/wasserstein.py | 97 +--------------- src/lightning_trainable/modules/__init__.py | 2 +- .../modules/simple_unet/__init__.py | 3 + .../modules/simple_unet/down_block.py | 29 +++++ .../modules/simple_unet/hparams.py | 23 ++++ .../modules/simple_unet/network.py | 73 ++++++++++++ .../modules/simple_unet/up_block.py | 29 +++++ .../modules/unet/network.py | 47 +++++++- .../trainable/trainable_hparams.py | 3 - 11 files changed, 312 insertions(+), 102 deletions(-) create mode 100644 src/lightning_trainable/metrics/sinkhorn.py create mode 100644 src/lightning_trainable/modules/simple_unet/__init__.py create mode 100644 src/lightning_trainable/modules/simple_unet/down_block.py create mode 100644 src/lightning_trainable/modules/simple_unet/hparams.py create mode 100644 src/lightning_trainable/modules/simple_unet/network.py create mode 100644 src/lightning_trainable/modules/simple_unet/up_block.py diff --git a/src/lightning_trainable/metrics/__init__.py b/src/lightning_trainable/metrics/__init__.py index d94ccf1..30bcacd 100644 --- a/src/lightning_trainable/metrics/__init__.py +++ b/src/lightning_trainable/metrics/__init__.py @@ -1,3 +1,4 @@ from .accuracy import accuracy from .error import error +from .sinkhorn import sinkhorn_auto as sinkhorn from .wasserstein import wasserstein diff --git a/src/lightning_trainable/metrics/sinkhorn.py b/src/lightning_trainable/metrics/sinkhorn.py new file mode 100644 index 0000000..623ecdc --- /dev/null +++ b/src/lightning_trainable/metrics/sinkhorn.py @@ -0,0 +1,107 @@ +import warnings + +import torch +from torch import Tensor + +import torch.nn.functional as F +import numpy as np + + +def sinkhorn(a: Tensor, b: Tensor, cost: Tensor, epsilon: float, steps: int = 1000) -> Tensor: + """ + Computes the Sinkhorn optimal transport plan from sample weights of two distributions. + This version does not use log-space computations, but allows for zero or negative weights. + + @param a: Sample weights from the first distribution in shape (n,) + @param b: Sample weights from the second distribution in shape (m,) + @param cost: Cost matrix in shape (n, m). + @param epsilon: Entropic regularization parameter. + @param steps: Number of iterations. + """ + if cost.shape != (len(a), len(b)): + raise ValueError(f"Expected cost to have shape {(len(a), len(b))}, but got {cost.shape}.") + + gain = torch.exp(-cost / epsilon) + + if gain.mean() < 1e-30: + warnings.warn(f"Detected low bandwidth ({epsilon:.1e}) relative to cost ({cost.mean().item():.1e}). " + f"You may experience numerical instabilities. Consider increasing epsilon or using sinkhorn_log.") + + # Initialize the dual variables. + u = torch.ones(len(a), dtype=a.dtype, device=a.device) + v = torch.ones(len(b), dtype=b.dtype, device=b.device) + + # Compute the Sinkhorn iterations. + for _ in range(steps): + v = b / (torch.matmul(gain.T, u) + 1e-50) + u = a / (torch.matmul(gain, v) + 1e-50) + + # Return the transport plan. 
+ return u[:, None] * gain * v[None, :] + + +def sinkhorn_log(log_a: Tensor, log_b: Tensor, cost: Tensor, epsilon: float, steps: int = 1000) -> Tensor: + """ + Computes the Sinkhorn optimal transport plan from sample weights of two distributions. + This version uses log-space computations to avoid numerical instabilities, but disallows zero or negative weights. + + @param log_a: Log sample weights from the first distribution in shape (n,) + @param log_b: Log sample weights from the second distribution in shape (m,) + @param cost: Cost matrix in shape (n, m). + @param epsilon: Entropic regularization parameter. + @param steps: Number of iterations. + """ + if cost.shape != (len(log_a), len(log_b)): + raise ValueError(f"Expected cost to have shape {(len(log_a), len(log_b))}, but got {cost.shape}.") + + log_gain = -cost / epsilon + + # Initialize the dual variables. + log_u = torch.zeros(len(log_a), dtype=log_a.dtype, device=log_a.device) + log_v = torch.zeros(len(log_b), dtype=log_b.dtype, device=log_b.device) + + # Compute the Sinkhorn iterations. + for _ in range(steps): + log_v = log_b - torch.logsumexp(log_gain + log_u[:, None], dim=0) + log_u = log_a - torch.logsumexp(log_gain + log_v[None, :], dim=1) + + plan = torch.exp(log_u[:, None] + log_gain + log_v[None, :]) + + if not torch.allclose(len(log_b) * plan.sum(dim=0), torch.ones(len(log_b), device=plan.device)) or not torch.allclose(len(log_a) * plan.sum(dim=1), torch.ones(len(log_a), device=plan.device)): + warnings.warn(f"Sinkhorn did not converge. Consider increasing epsilon or number of iterations.") + + # Return the transport plan. + return plan + + +def sinkhorn_auto(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: float = 1.0, steps: int = 1000) -> Tensor: + """ + Computes the Sinkhorn optimal transport plan from samples from two distributions. + See also: sinkhorn_log + + @param x: Samples from the first distribution in shape (n, ...). + @param y: Samples from the second distribution in shape (m, ...). + @param cost: Optional cost matrix in shape (n, m). + If not provided, the Euclidean distance is used. + @param epsilon: Optional entropic regularization parameter. + This parameter is normalized to the half-mean of the cost matrix. This helps keep the value independent + of the data dimensionality. Note that this behaviour is exclusive to this method; sinkhorn_log only accepts + the raw entropic regularization value. + @param steps: Number of iterations. + """ + if x.shape[1:] != y.shape[1:]: + raise ValueError(f"Expected x and y to live in the same feature space, " + f"but got {x.shape[1:]} and {y.shape[1:]}.") + if cost is None: + cost = x[:, None] - y[None, :] + cost = torch.flatten(cost, start_dim=2) + cost = torch.linalg.norm(cost, dim=-1) + + # Initialize epsilon independent of the data dimension (i.e. dependent on the mean cost) + epsilon = epsilon * cost.mean() / 2 + + # Initialize the sample weights. 
+ log_a = torch.zeros(len(x), device=x.device) - np.log(len(x)) + log_b = torch.zeros(len(y), device=y.device) - np.log(len(y)) + + return sinkhorn_log(log_a, log_b, cost, epsilon, steps) diff --git a/src/lightning_trainable/metrics/wasserstein.py b/src/lightning_trainable/metrics/wasserstein.py index e5dcb44..1fdb447 100644 --- a/src/lightning_trainable/metrics/wasserstein.py +++ b/src/lightning_trainable/metrics/wasserstein.py @@ -1,103 +1,8 @@ -import warnings - import torch from torch import Tensor - -def sinkhorn(a: Tensor, b: Tensor, cost: Tensor, epsilon: float, steps: int = 10) -> Tensor: - """ - Computes the Sinkhorn optimal transport plan from sample weights of two distributions. - This version does not use log-space computations, but allows for zero or negative weights. - - @param a: Sample weights from the first distribution in shape (n,) - @param b: Sample weights from the second distribution in shape (m,) - @param cost: Cost matrix in shape (n, m). - @param epsilon: Entropic regularization parameter. - @param steps: Number of iterations. - """ - if cost.shape != (len(a), len(b)): - raise ValueError(f"Expected cost to have shape {(len(a), len(b))}, but got {cost.shape}.") - - gain = torch.exp(-cost / epsilon) - - if gain.mean() < 1e-30: - warnings.warn(f"Detected low bandwidth ({epsilon:.1e}) relative to cost ({cost.mean().item():.1e}). " - f"You may experience numerical instabilities. Consider increasing epsilon or using sinkhorn_log.") - - # Initialize the dual variables. - u = torch.ones(len(a), dtype=a.dtype, device=a.device) - v = torch.ones(len(b), dtype=b.dtype, device=b.device) - - # Compute the Sinkhorn iterations. - for _ in range(steps): - v = b / (torch.matmul(gain.T, u) + 1e-50) - u = a / (torch.matmul(gain, v) + 1e-50) - - # Return the transport plan. - return u[:, None] * gain * v[None, :] - - -def sinkhorn_log(a: Tensor, b: Tensor, cost: Tensor, epsilon: float, steps: int = 10) -> Tensor: - """ - Computes the Sinkhorn optimal transport plan from sample weights of two distributions. - This version uses log-space computations to avoid numerical instabilities, but disallows zero or negative weights. - - @param a: Sample weights from the first distribution in shape (n,) - @param b: Sample weights from the second distribution in shape (m,) - @param cost: Cost matrix in shape (n, m). - @param epsilon: Entropic regularization parameter. - @param steps: Number of iterations. - """ - if torch.any(a <= 0) or torch.any(b <= 0): - raise ValueError("Expected sample weights to be non-negative.") - if cost.shape != (len(a), len(b)): - raise ValueError(f"Expected cost to have shape {(len(a), len(b))}, but got {cost.shape}.") - - log_gain = -cost / epsilon - - # Initialize the dual variables. - log_u = torch.zeros(len(a), dtype=a.dtype, device=a.device) - log_v = torch.zeros(len(b), dtype=b.dtype, device=b.device) - - # Compute the Sinkhorn iterations. - for _ in range(steps): - log_v = torch.log(b) - torch.logsumexp(log_gain + log_u[:, None], dim=0) - log_u = torch.log(a) - torch.logsumexp(log_gain + log_v[None, :], dim=1) - - # Return the transport plan. - return torch.exp(log_u[:, None] + log_gain + log_v[None, :]) - - -def sinkhorn_auto(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: float = None, steps: int = 10) -> Tensor: - """ - Computes the Sinkhorn optimal transport plan from samples from two distributions. - See also: sinkhorn_log - - @param x: Samples from the first distribution in shape (n, ...). 
- @param y: Samples from the second distribution in shape (m, ...). - @param cost: Optional cost matrix in shape (n, m). - If not provided, the Euclidean distance is used. - @param epsilon: Optional entropic regularization parameter. - If not provided, the half-mean of the cost matrix is used. - @param steps: Number of iterations. - """ - if x.shape[1:] != y.shape[1:]: - raise ValueError(f"Expected x and y to live in the same feature space, " - f"but got {x.shape[1:]} and {y.shape[1:]}.") - if cost is None: - cost = x[:, None] - y[None, :] - cost = torch.flatten(cost, start_dim=2) - cost = torch.linalg.norm(cost, dim=-1) - - if epsilon is None: - epsilon = cost.mean() / 2 - - # Initialize the sample weights. - a = torch.ones(len(x), dtype=x.dtype, device=x.device) / len(x) - b = torch.ones(len(y), dtype=y.dtype, device=y.device) / len(y) - - return sinkhorn_log(a, b, cost, epsilon, steps) +from .sinkhorn import sinkhorn_auto def wasserstein(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: float = 0.1, steps: int = 10) -> Tensor: diff --git a/src/lightning_trainable/modules/__init__.py b/src/lightning_trainable/modules/__init__.py index b5cb254..2f4a8c8 100644 --- a/src/lightning_trainable/modules/__init__.py +++ b/src/lightning_trainable/modules/__init__.py @@ -1,4 +1,4 @@ from .convolutional import ConvolutionalNetwork, ConvolutionalNetworkHParams from .fully_connected import FullyConnectedNetwork, FullyConnectedNetworkHParams from .hparams_module import HParamsModule -from .unet import UNet, UNetHParams, UNetBlockHParams +from .simple_unet import SimpleUNet, SimpleUNetHParams diff --git a/src/lightning_trainable/modules/simple_unet/__init__.py b/src/lightning_trainable/modules/simple_unet/__init__.py new file mode 100644 index 0000000..64c06ad --- /dev/null +++ b/src/lightning_trainable/modules/simple_unet/__init__.py @@ -0,0 +1,3 @@ + +from .hparams import SimpleUNetHParams +from .network import SimpleUNet diff --git a/src/lightning_trainable/modules/simple_unet/down_block.py b/src/lightning_trainable/modules/simple_unet/down_block.py new file mode 100644 index 0000000..3a3cd4f --- /dev/null +++ b/src/lightning_trainable/modules/simple_unet/down_block.py @@ -0,0 +1,29 @@ + +import torch +import torch.nn as nn +from torch.nn import Module + +from lightning_trainable.utils import get_activation + + +class SimpleUNetDownBlock(Module): + def __init__(self, in_channels: int, out_channels: int, kernel_size: int, block_size: int = 2, activation: str = "relu"): + super().__init__() + + self.channels = torch.linspace(in_channels, out_channels, block_size + 1, dtype=torch.int64).tolist() + + layers = [] + for c1, c2 in zip(self.channels[:-2], self.channels[1:-1]): + layers.append(nn.Conv2d(c1, c2, kernel_size, padding="same")) + layers.append(get_activation(activation)(inplace=True)) + + layers.append(nn.Conv2d(self.channels[-2], self.channels[-1], kernel_size, padding="same")) + layers.append(nn.MaxPool2d(2)) + + self.block = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.block(x) + + def extra_repr(self) -> str: + return f"in_channels={self.channels[0]}, out_channels={self.channels[-1]}, kernel_size={self.block[0].kernel_size[0]}, block_size={len(self.channels) - 1}" diff --git a/src/lightning_trainable/modules/simple_unet/hparams.py b/src/lightning_trainable/modules/simple_unet/hparams.py new file mode 100644 index 0000000..bb4714d --- /dev/null +++ b/src/lightning_trainable/modules/simple_unet/hparams.py @@ -0,0 +1,23 @@ +from 
lightning_trainable.hparams import HParams + + +class SimpleUNetHParams(HParams): + input_shape: tuple[int, int, int] + conditions: int = 0 + + channels: list[int] + kernel_sizes: list[int] + fc_widths: list[int] + activation: str = "ReLU" + + block_size: int = 2 + + @classmethod + def validate_parameters(cls, hparams): + hparams = super().validate_parameters(hparams) + + if len(hparams.channels) + 2 != len(hparams.kernel_sizes): + raise ValueError(f"Number of channels ({len(hparams.channels)}) + 2 must be equal " + f"to the number of kernel sizes ({len(hparams.kernel_sizes)})") + + return hparams diff --git a/src/lightning_trainable/modules/simple_unet/network.py b/src/lightning_trainable/modules/simple_unet/network.py new file mode 100644 index 0000000..f2fce9a --- /dev/null +++ b/src/lightning_trainable/modules/simple_unet/network.py @@ -0,0 +1,73 @@ + +import torch +import torch.nn as nn +from torch import Tensor + +from ..fully_connected import FullyConnectedNetwork +from ..hparams_module import HParamsModule + +from .down_block import SimpleUNetDownBlock +from .hparams import SimpleUNetHParams +from .up_block import SimpleUNetUpBlock + + +class SimpleUNet(HParamsModule): + hparams: SimpleUNetHParams + + def __init__(self, hparams: dict | SimpleUNetHParams): + super().__init__(hparams) + + channels = [self.hparams.input_shape[0], *self.hparams.channels] + + self.down_blocks = nn.ModuleList([ + SimpleUNetDownBlock(c1, c2, kernel_size, self.hparams.block_size, self.hparams.activation) + for c1, c2, kernel_size in zip(channels[:-1], channels[1:], self.hparams.kernel_sizes) + ]) + + fc_channels = self.hparams.channels[-1] + height = self.hparams.input_shape[1] // 2 ** (len(channels) - 1) + width = self.hparams.input_shape[2] // 2 ** (len(channels) - 1) + + fc_hparams = dict( + input_dims=self.hparams.conditions + fc_channels * height * width, + output_dims=fc_channels * height * width, + activation=self.hparams.activation, + layer_widths=self.hparams.fc_widths, + ) + self.fc = FullyConnectedNetwork(fc_hparams) + self.up_blocks = nn.ModuleList([ + SimpleUNetUpBlock(c2, c1, kernel_size, self.hparams.block_size, self.hparams.activation) + for c1, c2, kernel_size in zip(channels[:-1], channels[1:], self.hparams.kernel_sizes) + ][::-1]) + + def forward(self, image: Tensor, condition: Tensor = None) -> Tensor: + residuals = [] + for block in self.down_blocks: + image = block(image) + residuals.append(image) + + shape = image.shape + image = image.flatten(start_dim=1) + + if condition is not None: + image = torch.cat([image, condition], dim=1) + + image = self.fc(image) + + image = image.reshape(shape) + + for block in self.up_blocks: + residual = residuals.pop() + image = block(image + residual) + + return image + + def down(self, image: Tensor) -> Tensor: + for block in self.down_blocks: + image = block(image) + return image + + def up(self, image: Tensor) -> Tensor: + for block in self.up_blocks: + image = block(image) + return image diff --git a/src/lightning_trainable/modules/simple_unet/up_block.py b/src/lightning_trainable/modules/simple_unet/up_block.py new file mode 100644 index 0000000..b389626 --- /dev/null +++ b/src/lightning_trainable/modules/simple_unet/up_block.py @@ -0,0 +1,29 @@ + +import torch +import torch.nn as nn +from torch.nn import Module + +from lightning_trainable.utils import get_activation + + +class SimpleUNetUpBlock(Module): + def __init__(self, in_channels: int, out_channels: int, kernel_size: int, block_size: int = 2, activation: str = "relu"): + 
super().__init__() + + self.channels = torch.linspace(in_channels, out_channels, block_size + 1, dtype=torch.int64).tolist() + + layers = [] + for c1, c2 in zip(self.channels[:-2], self.channels[1:-1]): + layers.append(nn.Conv2d(c1, c2, kernel_size, padding="same")) + layers.append(get_activation(activation)(inplace=True)) + + layers.append(nn.Conv2d(self.channels[-2], self.channels[-1], kernel_size, padding="same")) + layers.append(nn.ConvTranspose2d(self.channels[-1], self.channels[-1], 2, 2)) + + self.block = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.block(x) + + def extra_repr(self) -> str: + return f"in_channels={self.channels[0]}, out_channels={self.channels[-1]}, kernel_size={self.block[0].kernel_size[0]}, block_size={len(self.channels) - 1}" diff --git a/src/lightning_trainable/modules/unet/network.py b/src/lightning_trainable/modules/unet/network.py index f40a31a..9f04d3f 100644 --- a/src/lightning_trainable/modules/unet/network.py +++ b/src/lightning_trainable/modules/unet/network.py @@ -20,6 +20,11 @@ class UNet(HParamsModule): def __init__(self, hparams: dict | UNetHParams): super().__init__(hparams) + self.down_blocks = nn.ModuleList() + self.up_blocks = nn.ModuleList() + self.levels = nn.ModuleList() + self.fc = None + self.network = self.configure_network( input_shape=self.hparams.input_shape, output_shape=self.hparams.output_shape, @@ -35,11 +40,15 @@ def configure_levels(self, input_shape: (int, int, int), output_shape: (int, int Recursively configures the levels of the UNet. """ if not down_blocks: - return self.configure_fc(input_shape, output_shape) + self.fc = self.configure_fc(input_shape, output_shape) + return self.fc down_block = down_blocks[0] up_block = up_blocks[-1] + down_channels = [input_shape[0], *down_block["channels"]] + up_channels = [*up_block["channels"], output_shape[0]] + down_hparams = AttributeDict( channels=[input_shape[0], *down_block["channels"]], kernel_sizes=down_block["kernel_sizes"], @@ -63,11 +72,15 @@ def configure_levels(self, input_shape: (int, int, int), output_shape: (int, int next_output_shape = (up_hparams.channels[0], output_shape[1] // 2, output_shape[2] // 2) if self.hparams.skip_mode == "concat": + up_channels[0] += down_channels[-1] up_hparams.channels[0] += down_hparams.channels[-1] down_block = ConvolutionalBlock(down_hparams) up_block = ConvolutionalBlock(up_hparams) + self.down_blocks.append(down_block) + self.up_blocks.append(up_block) + next_level = self.configure_levels( input_shape=next_input_shape, output_shape=next_output_shape, @@ -75,12 +88,42 @@ def configure_levels(self, input_shape: (int, int, int), output_shape: (int, int up_blocks=up_blocks[:-1], ) - return nn.Sequential( + level = nn.Sequential( down_block, SkipConnection(next_level, mode=self.hparams.skip_mode), up_block, ) + self.levels.append(level) + + return level + + def configure_down_block(self, channels, kernel_sizes): + hparams = AttributeDict( + channels=channels, + kernel_sizes=kernel_sizes, + activation=self.hparams.activation, + padding="same", + pool=True, + pool_direction="down", + pool_position="last" + ) + + return ConvolutionalBlock(hparams) + + def configure_up_block(self, channels, kernel_sizes): + hparams = AttributeDict( + channels=channels, + kernel_sizes=kernel_sizes, + activation=self.hparams.activation, + padding="same", + pool=True, + pool_direction="up", + pool_position="first" + ) + + return ConvolutionalBlock(hparams) + def configure_fc(self, input_shape: (int, int, int), 
output_shape: (int, int, int)): """ Configures the lowest level of the UNet as a fully connected network. diff --git a/src/lightning_trainable/trainable/trainable_hparams.py b/src/lightning_trainable/trainable/trainable_hparams.py index dc8fc13..82f3f09 100644 --- a/src/lightning_trainable/trainable/trainable_hparams.py +++ b/src/lightning_trainable/trainable/trainable_hparams.py @@ -1,12 +1,9 @@ -import warnings - from lightning_trainable.hparams import HParams from lightning.pytorch.profilers import Profiler from lightning_trainable.utils import deprecate - class TrainableHParams(HParams): # name of the loss, your `compute_metrics` should return a dict with this name in its keys loss: str = "loss" From 758864efb8482806304c76f6f17126cab55e7aa4 Mon Sep 17 00:00:00 2001 From: Lars Kuehmichel Date: Mon, 16 Oct 2023 16:13:34 +0200 Subject: [PATCH 31/53] Improved Tests + Cleaning Up --- .../modules/convolutional/__init__.py | 5 - .../modules/convolutional/block.py | 74 --------- .../modules/convolutional/block_hparams.py | 25 --- .../modules/convolutional/hparams.py | 8 - .../modules/convolutional/network.py | 28 ---- .../modules/fully_connected/network.py | 6 +- .../modules/unet/__init__.py | 2 - .../modules/unet/hparams.py | 66 -------- .../modules/unet/network.py | 143 ------------------ .../modules/unet/skip_connection.py | 23 --- .../modules/unet/temporary_flatten.py | 14 -- .../trainable/trainable.py | 9 +- tests/module_tests/__init__.py | 0 tests/module_tests/test_trainable.py | 73 +++++++++ tests/test_trainable.py | 72 --------- 15 files changed, 85 insertions(+), 463 deletions(-) delete mode 100644 src/lightning_trainable/modules/convolutional/__init__.py delete mode 100644 src/lightning_trainable/modules/convolutional/block.py delete mode 100644 src/lightning_trainable/modules/convolutional/block_hparams.py delete mode 100644 src/lightning_trainable/modules/convolutional/hparams.py delete mode 100644 src/lightning_trainable/modules/convolutional/network.py delete mode 100644 src/lightning_trainable/modules/unet/__init__.py delete mode 100644 src/lightning_trainable/modules/unet/hparams.py delete mode 100644 src/lightning_trainable/modules/unet/network.py delete mode 100644 src/lightning_trainable/modules/unet/skip_connection.py delete mode 100644 src/lightning_trainable/modules/unet/temporary_flatten.py create mode 100644 tests/module_tests/__init__.py create mode 100644 tests/module_tests/test_trainable.py delete mode 100644 tests/test_trainable.py diff --git a/src/lightning_trainable/modules/convolutional/__init__.py b/src/lightning_trainable/modules/convolutional/__init__.py deleted file mode 100644 index 7111d1e..0000000 --- a/src/lightning_trainable/modules/convolutional/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .block import ConvolutionalBlock -from .block_hparams import ConvolutionalBlockHParams - -from .network import ConvolutionalNetwork -from .hparams import ConvolutionalNetworkHParams diff --git a/src/lightning_trainable/modules/convolutional/block.py b/src/lightning_trainable/modules/convolutional/block.py deleted file mode 100644 index e83f743..0000000 --- a/src/lightning_trainable/modules/convolutional/block.py +++ /dev/null @@ -1,74 +0,0 @@ -import torch -import torch.nn as nn - -import lightning_trainable.utils as utils - -from ..hparams_module import HParamsModule - -from .block_hparams import ConvolutionalBlockHParams - - -class ConvolutionalBlock(HParamsModule): - """ - Implements a series of convolutions, each followed by an activation function. 
- """ - - hparams: ConvolutionalBlockHParams - - def __init__(self, hparams: dict | ConvolutionalBlockHParams): - super().__init__(hparams) - - self.network = self.configure_network() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.network(x) - - def configure_network(self): - # construct convolutions - convolutions = [] - cck = zip(self.hparams.channels[:-1], self.hparams.channels[1:], self.hparams.kernel_sizes) - for in_channels, out_channels, kernel_size in cck: - conv = nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - padding=self.hparams.padding, - dilation=self.hparams.dilation, - groups=self.hparams.groups, - bias=self.hparams.bias, - padding_mode=self.hparams.padding_mode, - ) - - convolutions.append(conv) - - # construct activations - activations = [] - for _ in range(len(self.hparams.channels) - 2): - activations.append(utils.get_activation(self.hparams.activation)(inplace=True)) - - layers = list(utils.zip(convolutions, activations, exhaustive=True, nested=False)) - - # add pooling layer if requested - if self.hparams.pool: - match self.hparams.pool_direction: - case "up": - if self.hparams.pool_position == "first": - channels = self.hparams.channels[0] - else: - channels = self.hparams.channels[-1] - - pool = nn.ConvTranspose2d(channels, channels, 2, 2) - case "down": - pool = nn.MaxPool2d(2, 2) - case _: - raise NotImplementedError(f"Unrecognized pool direction '{self.hparams.pool_direction}'.") - - match self.hparams.pool_position: - case "first": - layers.insert(0, pool) - case "last": - layers.append(pool) - case _: - raise NotImplementedError(f"Unrecognized pool position '{self.hparams.pool_position}'.") - - return nn.Sequential(*layers) diff --git a/src/lightning_trainable/modules/convolutional/block_hparams.py b/src/lightning_trainable/modules/convolutional/block_hparams.py deleted file mode 100644 index b901ea9..0000000 --- a/src/lightning_trainable/modules/convolutional/block_hparams.py +++ /dev/null @@ -1,25 +0,0 @@ -from lightning_trainable.hparams import HParams, Choice - - -class ConvolutionalBlockHParams(HParams): - channels: list[int] - kernel_sizes: list[int] - activation: str = "relu" - padding: str | int = 0 - dilation: int = 1 - groups: int = 1 - bias: bool = True - padding_mode: str = "zeros" - - pool: bool = False - pool_direction: Choice("up", "down") = "down" - pool_position: Choice("first", "last") = "last" - - @classmethod - def validate_parameters(cls, hparams): - hparams = super().validate_parameters(hparams) - - if not len(hparams.channels) == len(hparams.kernel_sizes): - raise ValueError(f"{cls.__name__} needs same number of channels and kernel sizes.") - - return hparams diff --git a/src/lightning_trainable/modules/convolutional/hparams.py b/src/lightning_trainable/modules/convolutional/hparams.py deleted file mode 100644 index 046e54e..0000000 --- a/src/lightning_trainable/modules/convolutional/hparams.py +++ /dev/null @@ -1,8 +0,0 @@ - -from lightning_trainable.hparams import HParams - -from .block_hparams import ConvolutionalBlockHParams - - -class ConvolutionalNetworkHParams(HParams): - block_hparams: list[ConvolutionalBlockHParams] diff --git a/src/lightning_trainable/modules/convolutional/network.py b/src/lightning_trainable/modules/convolutional/network.py deleted file mode 100644 index c4ad3ad..0000000 --- a/src/lightning_trainable/modules/convolutional/network.py +++ /dev/null @@ -1,28 +0,0 @@ -import torch -import torch.nn as nn - - -from ..hparams_module import 
HParamsModule -from ..sequential_mixin import SequentialMixin - -from .block import ConvolutionalBlock -from .hparams import ConvolutionalNetworkHParams - - -class ConvolutionalNetwork(SequentialMixin, HParamsModule): - """ - Implements a series of pooled convolutional blocks. - """ - hparams: ConvolutionalNetworkHParams - - def __init__(self, hparams): - super().__init__(hparams) - self.network = self.configure_network() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.network(x) - - def configure_network(self): - blocks = [ConvolutionalBlock(hparams) for hparams in self.hparams.block_hparams] - - return nn.Sequential(*blocks) diff --git a/src/lightning_trainable/modules/fully_connected/network.py b/src/lightning_trainable/modules/fully_connected/network.py index bb9ecf3..b4a133e 100644 --- a/src/lightning_trainable/modules/fully_connected/network.py +++ b/src/lightning_trainable/modules/fully_connected/network.py @@ -1,4 +1,6 @@ -import torch + +from torch import Tensor + import torch.nn as nn from lightning_trainable.utils import get_activation @@ -19,7 +21,7 @@ def __init__(self, hparams): self.network = self.configure_network() - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: Tensor) -> Tensor: return self.network(x) def configure_network(self): diff --git a/src/lightning_trainable/modules/unet/__init__.py b/src/lightning_trainable/modules/unet/__init__.py deleted file mode 100644 index 964f0cb..0000000 --- a/src/lightning_trainable/modules/unet/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .hparams import UNetHParams, UNetBlockHParams -from .network import UNet diff --git a/src/lightning_trainable/modules/unet/hparams.py b/src/lightning_trainable/modules/unet/hparams.py deleted file mode 100644 index 6afc972..0000000 --- a/src/lightning_trainable/modules/unet/hparams.py +++ /dev/null @@ -1,66 +0,0 @@ -from lightning_trainable.hparams import HParams, Choice - - -class UNetBlockHParams(HParams): - # this is a reduced variant of the ConvolutionalBlockHParams - # where additional parameters are generated automatically by UNet - channels: list[int] - kernel_sizes: list[int] - - @classmethod - def validate_parameters(cls, hparams): - hparams = super().validate_parameters(hparams) - - if len(hparams.channels) + 1 != len(hparams.kernel_sizes): - raise ValueError(f"Block needs one more kernel size than channels.") - - return hparams - - -class UNetHParams(HParams): - # input/output image size, not including batch dimensions - input_shape: tuple[int, int, int] - output_shape: tuple[int, int, int] - - # list of hparams for individual down/up blocks - down_blocks: list[dict | UNetBlockHParams] - up_blocks: list[dict | UNetBlockHParams] - - # hidden layer sizes for the bottom, fully connected part of the UNet - bottom_widths: list[int] - - # skip connection mode - skip_mode: Choice("add", "concat", "none") = "add" - activation: str = "relu" - - @classmethod - def validate_parameters(cls, hparams): - hparams = super().validate_parameters(hparams) - - for i in range(len(hparams.down_blocks)): - hparams.down_blocks[i] = UNetBlockHParams(**hparams.down_blocks[i]) - for i in range(len(hparams.up_blocks)): - hparams.up_blocks[i] = UNetBlockHParams(**hparams.up_blocks[i]) - - url = "https://github.com/LarsKue/lightning-trainable/" - if hparams.input_shape[1:] != hparams.output_shape[1:]: - raise ValueError(f"Different image sizes for input and output are not yet supported. 
" - f"If you need this feature, please file an issue or pull request at {url}.") - - if hparams.input_shape[1] % 2 or hparams.input_shape[2] % 2: - raise ValueError(f"Odd input shape is not yet supported. " - f"If you need this feature, please file an issue or pull request at {url}.") - - minimum_size = 2 ** len(hparams.down_blocks) - if hparams.input_shape[1] < minimum_size or hparams.input_shape[2] < minimum_size: - raise ValueError(f"Input shape {hparams.input_shape[1:]} is too small for {len(hparams.down_blocks)} " - f"down blocks. Minimum size is {(minimum_size, minimum_size)}.") - - if hparams.skip_mode == "add": - # ensure matching number of channels for down output as up input - for i, (down_block, up_block) in enumerate(zip(hparams.down_blocks, reversed(hparams.up_blocks))): - if down_block["channels"][-1] != up_block["channels"][0]: - raise ValueError(f"Output channels of down block {i} must match input channels of up block " - f"{len(hparams.up_blocks) - (i + 1)} for skip mode '{hparams.skip_mode}'.") - - return hparams diff --git a/src/lightning_trainable/modules/unet/network.py b/src/lightning_trainable/modules/unet/network.py deleted file mode 100644 index 9f04d3f..0000000 --- a/src/lightning_trainable/modules/unet/network.py +++ /dev/null @@ -1,143 +0,0 @@ -import torch -import torch.nn as nn - -import math - -from lightning_trainable.hparams import AttributeDict - -from ..convolutional import ConvolutionalBlock -from ..fully_connected import FullyConnectedNetwork -from ..hparams_module import HParamsModule - -from .hparams import UNetHParams -from .skip_connection import SkipConnection -from .temporary_flatten import TemporaryFlatten - - -class UNet(HParamsModule): - hparams: UNetHParams - - def __init__(self, hparams: dict | UNetHParams): - super().__init__(hparams) - - self.down_blocks = nn.ModuleList() - self.up_blocks = nn.ModuleList() - self.levels = nn.ModuleList() - self.fc = None - - self.network = self.configure_network( - input_shape=self.hparams.input_shape, - output_shape=self.hparams.output_shape, - down_blocks=self.hparams.down_blocks, - up_blocks=self.hparams.up_blocks, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.network(x) - - def configure_levels(self, input_shape: (int, int, int), output_shape: (int, int, int), down_blocks: list[dict], up_blocks: list[dict]): - """ - Recursively configures the levels of the UNet. 
- """ - if not down_blocks: - self.fc = self.configure_fc(input_shape, output_shape) - return self.fc - - down_block = down_blocks[0] - up_block = up_blocks[-1] - - down_channels = [input_shape[0], *down_block["channels"]] - up_channels = [*up_block["channels"], output_shape[0]] - - down_hparams = AttributeDict( - channels=[input_shape[0], *down_block["channels"]], - kernel_sizes=down_block["kernel_sizes"], - activation=self.hparams.activation, - padding="same", - pool=True, - pool_direction="down", - pool_position="last" - ) - up_hparams = AttributeDict( - channels=[*up_block["channels"], output_shape[0]], - kernel_sizes=up_block["kernel_sizes"], - activation=self.hparams.activation, - padding="same", - pool=True, - pool_direction="up", - pool_position="first" - ) - - next_input_shape = (down_hparams.channels[-1], input_shape[1] // 2, input_shape[2] // 2) - next_output_shape = (up_hparams.channels[0], output_shape[1] // 2, output_shape[2] // 2) - - if self.hparams.skip_mode == "concat": - up_channels[0] += down_channels[-1] - up_hparams.channels[0] += down_hparams.channels[-1] - - down_block = ConvolutionalBlock(down_hparams) - up_block = ConvolutionalBlock(up_hparams) - - self.down_blocks.append(down_block) - self.up_blocks.append(up_block) - - next_level = self.configure_levels( - input_shape=next_input_shape, - output_shape=next_output_shape, - down_blocks=down_blocks[1:], - up_blocks=up_blocks[:-1], - ) - - level = nn.Sequential( - down_block, - SkipConnection(next_level, mode=self.hparams.skip_mode), - up_block, - ) - - self.levels.append(level) - - return level - - def configure_down_block(self, channels, kernel_sizes): - hparams = AttributeDict( - channels=channels, - kernel_sizes=kernel_sizes, - activation=self.hparams.activation, - padding="same", - pool=True, - pool_direction="down", - pool_position="last" - ) - - return ConvolutionalBlock(hparams) - - def configure_up_block(self, channels, kernel_sizes): - hparams = AttributeDict( - channels=channels, - kernel_sizes=kernel_sizes, - activation=self.hparams.activation, - padding="same", - pool=True, - pool_direction="up", - pool_position="first" - ) - - return ConvolutionalBlock(hparams) - - def configure_fc(self, input_shape: (int, int, int), output_shape: (int, int, int)): - """ - Configures the lowest level of the UNet as a fully connected network. - """ - hparams = dict( - input_dims=math.prod(input_shape), - output_dims=math.prod(output_shape), - layer_widths=self.hparams.bottom_widths, - activation=self.hparams.activation, - ) - return TemporaryFlatten(FullyConnectedNetwork(hparams), input_shape, output_shape) - - def configure_network(self, input_shape: (int, int, int), output_shape: (int, int, int), down_blocks: list[dict], up_blocks: list[dict]): - """ - Configures the UNet. 
- """ - return self.configure_levels(input_shape, output_shape, down_blocks, up_blocks) diff --git a/src/lightning_trainable/modules/unet/skip_connection.py b/src/lightning_trainable/modules/unet/skip_connection.py deleted file mode 100644 index c6ac170..0000000 --- a/src/lightning_trainable/modules/unet/skip_connection.py +++ /dev/null @@ -1,23 +0,0 @@ -import torch -import torch.nn as nn - - -class SkipConnection(nn.Module): - def __init__(self, inner: nn.Module, mode: str = "add"): - super().__init__() - self.inner = inner - self.mode = mode - - def forward(self, x: torch.Tensor) -> torch.Tensor: - match self.mode: - case "add": - return self.inner(x) + x - case "concat": - return torch.cat((self.inner(x), x), dim=1) - case "none": - return self.inner(x) - case other: - raise NotImplementedError(f"Unrecognized skip connection mode '{other}'.") - - def extra_repr(self) -> str: - return f"mode={self.mode}" diff --git a/src/lightning_trainable/modules/unet/temporary_flatten.py b/src/lightning_trainable/modules/unet/temporary_flatten.py deleted file mode 100644 index 4b23244..0000000 --- a/src/lightning_trainable/modules/unet/temporary_flatten.py +++ /dev/null @@ -1,14 +0,0 @@ -import torch -import torch.nn as nn - - -class TemporaryFlatten(nn.Module): - def __init__(self, inner: nn.Module, input_shape, output_shape): - super().__init__() - self.inner = inner - self.input_shape = input_shape - self.output_shape = output_shape - - def forward(self, x: torch.Tensor) -> torch.Tensor: - out = self.inner(x.flatten(1)) - return out.reshape(x.shape[0], *self.output_shape) diff --git a/src/lightning_trainable/trainable/trainable.py b/src/lightning_trainable/trainable/trainable.py index e43c8ef..7a058fc 100644 --- a/src/lightning_trainable/trainable/trainable.py +++ b/src/lightning_trainable/trainable/trainable.py @@ -291,7 +291,14 @@ def fit_fast(self, device="cuda"): self.train() self.to(device) - optimizer = self.configure_optimizers()["optimizer"] + maybe_optimizer = self.configure_optimizers() + if isinstance(maybe_optimizer, dict): + optimizer = maybe_optimizer["optimizer"] + elif isinstance(maybe_optimizer, torch.optim.Optimizer): + optimizer = maybe_optimizer + else: + raise RuntimeError("Invalid optimizer") + dataloader = self.train_dataloader() loss = None diff --git a/tests/module_tests/__init__.py b/tests/module_tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/module_tests/test_trainable.py b/tests/module_tests/test_trainable.py new file mode 100644 index 0000000..a64b18a --- /dev/null +++ b/tests/module_tests/test_trainable.py @@ -0,0 +1,73 @@ + +import pytest + +import torch +import torch.nn as nn +from torch.utils.data import TensorDataset + +from lightning_trainable.trainable import Trainable + + +@pytest.fixture +def dummy_dataset(): + return TensorDataset(torch.randn(128, 2)) + + +@pytest.fixture +def dummy_network(): + return nn.Linear(2, 2) + + +@pytest.fixture +def dummy_hparams(): + return dict( + max_epochs=10, + batch_size=4, + ) + + +@pytest.fixture +def dummy_model_cls(dummy_network, dummy_dataset): + class DummyModel(Trainable): + def __init__(self, hparams): + super().__init__(hparams, train_data=dummy_dataset, val_data=dummy_dataset, test_data=dummy_dataset) + self.network = dummy_network + + def compute_metrics(self, batch, batch_idx) -> dict: + return dict( + loss=torch.tensor(0.0, requires_grad=True) + ) + + return DummyModel + + +@pytest.fixture +def dummy_model(dummy_model_cls, dummy_hparams): + return dummy_model_cls(dummy_hparams) + 
+ +def test_fit(dummy_model): + train_metrics = dummy_model.fit() + + assert isinstance(train_metrics, dict) + assert "training/loss" in train_metrics + assert "validation/loss" in train_metrics + + +def test_fit_fast(dummy_model): + loss = dummy_model.fit_fast() + + assert torch.isclose(loss, torch.tensor(0.0)) + + +def test_hparams_invariant(dummy_model_cls, dummy_hparams): + """ Ensure HParams are left unchanged after instantiation and training """ + hparams = dummy_hparams.copy() + + dummy_model1 = dummy_model_cls(hparams) + + assert hparams == dummy_hparams + + dummy_model1.fit() + + assert hparams == dummy_hparams diff --git a/tests/test_trainable.py b/tests/test_trainable.py deleted file mode 100644 index b6e7952..0000000 --- a/tests/test_trainable.py +++ /dev/null @@ -1,72 +0,0 @@ -import torch -from torch.utils.data import TensorDataset, Dataset - -from lightning_trainable import Trainable, TrainableHParams - - -def test_instantiate(): - hparams = TrainableHParams(max_epochs=10, batch_size=32) - Trainable(hparams) - - -def test_simple_model(): - class SimpleTrainable(Trainable): - def __init__(self, hparams: TrainableHParams | dict, - train_data: Dataset = None, - val_data: Dataset = None, - test_data: Dataset = None - ): - super().__init__(hparams, train_data, val_data, test_data) - self.param = torch.nn.Parameter(torch.randn(8, 1)) - - def compute_metrics(self, batch, batch_idx) -> dict: - return { - "loss": ((batch[0] @ self.param) ** 2).mean() - } - - train_data = TensorDataset(torch.randn(128, 8)) - - hparams = TrainableHParams( - accelerator="cpu", - max_epochs=10, - batch_size=32, - lr_scheduler="OneCycleLR" - ) - model = SimpleTrainable(hparams, train_data=train_data) - model.fit() - - model.load_checkpoint() - - -def test_double_train(): - class SimpleTrainable(Trainable): - def __init__(self, hparams: TrainableHParams | dict, - train_data: Dataset = None, - val_data: Dataset = None, - test_data: Dataset = None - ): - super().__init__(hparams, train_data, val_data, test_data) - self.param = torch.nn.Parameter(torch.randn(8, 1)) - - def compute_metrics(self, batch, batch_idx) -> dict: - return { - "loss": ((batch[0] @ self.param) ** 2).mean() - } - - hparams = TrainableHParams( - accelerator="cpu", - max_epochs=1, - batch_size=8, - optimizer=dict( - name="Adam", - lr=1e-3, - ) - ) - - train_data = TensorDataset(torch.randn(128, 8)) - - t1 = SimpleTrainable(hparams, train_data=train_data, val_data=train_data) - t1.fit() - - t2 = SimpleTrainable(hparams, train_data=train_data, val_data=train_data) - t2.fit() From 013f5a210830501284451ddabea6dbbafcc665c8 Mon Sep 17 00:00:00 2001 From: Lars Kuehmichel Date: Mon, 16 Oct 2023 16:33:15 +0200 Subject: [PATCH 32/53] Add more Tests + Remove Unnecessary Workaround --- src/lightning_trainable/modules/__init__.py | 1 - .../trainable/trainable.py | 7 +-- tests/module_tests/test_trainable.py | 46 +++++++++++--- tests/test_unet.py | 60 ------------------- 4 files changed, 42 insertions(+), 72 deletions(-) delete mode 100644 tests/test_unet.py diff --git a/src/lightning_trainable/modules/__init__.py b/src/lightning_trainable/modules/__init__.py index 2f4a8c8..a710cc6 100644 --- a/src/lightning_trainable/modules/__init__.py +++ b/src/lightning_trainable/modules/__init__.py @@ -1,4 +1,3 @@ -from .convolutional import ConvolutionalNetwork, ConvolutionalNetworkHParams from .fully_connected import FullyConnectedNetwork, FullyConnectedNetworkHParams from .hparams_module import HParamsModule from .simple_unet import SimpleUNet, SimpleUNetHParams 
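Aside: a minimal usage sketch for the module exports kept in the hunk above (FullyConnectedNetwork shown here). The hparams keys are inferred from how the network code in this series reads them, so treat the required fields and their defaults as assumptions rather than the package's documented API.

    import torch
    from lightning_trainable.modules import FullyConnectedNetwork

    # Assumed hparams keys (input_dims, output_dims, layer_widths, activation, dropout, norm),
    # taken from the diffs in this series; exact required fields and defaults may differ.
    net = FullyConnectedNetwork(dict(
        input_dims=2,           # an int, or "lazy" for a LazyLinear input layer
        output_dims=2,
        layer_widths=[16, 16],
        activation="relu",
        dropout=0.0,            # 0 disables the Dropout layers
        norm="none",            # "none" | "batch" | "layer"
    ))

    y = net(torch.randn(8, 2))  # -> tensor of shape (8, 2)
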
diff --git a/src/lightning_trainable/trainable/trainable.py b/src/lightning_trainable/trainable/trainable.py index 814ca5f..16c62e3 100644 --- a/src/lightning_trainable/trainable/trainable.py +++ b/src/lightning_trainable/trainable/trainable.py @@ -1,8 +1,9 @@ +import lightning import os import pathlib - -import lightning import torch + +from copy import deepcopy from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint, EarlyStopping from lightning.pytorch.loggers import Logger, TensorBoardLogger from torch.utils.data import DataLoader, Dataset, IterableDataset @@ -34,8 +35,6 @@ def __init__( if not isinstance(hparams, self.hparams_type): hparams = self.hparams_type(**hparams) self.save_hyperparameters(hparams) - # workaround for https://github.com/Lightning-AI/lightning/issues/17889 - self._hparams_name = "hparams" self.train_data = train_data self.val_data = val_data diff --git a/tests/module_tests/test_trainable.py b/tests/module_tests/test_trainable.py index a64b18a..718e5bb 100644 --- a/tests/module_tests/test_trainable.py +++ b/tests/module_tests/test_trainable.py @@ -5,7 +5,7 @@ import torch.nn as nn from torch.utils.data import TensorDataset -from lightning_trainable.trainable import Trainable +from lightning_trainable.trainable import Trainable, TrainableHParams @pytest.fixture @@ -19,16 +19,23 @@ def dummy_network(): @pytest.fixture -def dummy_hparams(): - return dict( - max_epochs=10, - batch_size=4, - ) +def dummy_hparams_cls(): + class DummyHParams(TrainableHParams): + max_epochs: int = 10 + batch_size: int = 4 + return DummyHParams @pytest.fixture -def dummy_model_cls(dummy_network, dummy_dataset): +def dummy_hparams(dummy_hparams_cls): + return dummy_hparams_cls() + + +@pytest.fixture +def dummy_model_cls(dummy_network, dummy_dataset, dummy_hparams_cls): class DummyModel(Trainable): + hparams: dummy_hparams_cls + def __init__(self, hparams): super().__init__(hparams, train_data=dummy_dataset, val_data=dummy_dataset, test_data=dummy_dataset) self.network = dummy_network @@ -71,3 +78,28 @@ def test_hparams_invariant(dummy_model_cls, dummy_hparams): dummy_model1.fit() assert hparams == dummy_hparams + + +def test_checkpoint(dummy_model): + dummy_model.fit() + + trained_model = dummy_model.load_checkpoint() + + +def test_nested_checkpoint(dummy_model_cls, dummy_hparams_cls): + + class MyHParams(dummy_hparams_cls): + pass + + class MyModel(dummy_model_cls): + def __init__(self, hparams): + super().__init__(hparams) + + hparams = MyHParams() + model = MyModel(hparams) + + assert model._hparams_name == "hparams" + + model.fit() + + model.load_checkpoint() diff --git a/tests/test_unet.py b/tests/test_unet.py deleted file mode 100644 index 069b707..0000000 --- a/tests/test_unet.py +++ /dev/null @@ -1,60 +0,0 @@ - -import pytest - -import torch - -from lightning_trainable.modules import UNet, UNetHParams - - -@pytest.mark.parametrize("skip_mode", ["add", "concat", "none"]) -@pytest.mark.parametrize("width", [32, 48, 64]) -@pytest.mark.parametrize("height", [32, 48, 64]) -@pytest.mark.parametrize("channels", [1, 3, 5]) -def test_basic(skip_mode, channels, width, height): - hparams = UNetHParams( - input_shape=(channels, height, width), - output_shape=(channels, height, width), - down_blocks=[ - dict(channels=[16, 32], kernel_sizes=[3, 3, 3]), - dict(channels=[32, 64], kernel_sizes=[3, 3, 3]), - ], - up_blocks=[ - dict(channels=[64, 32], kernel_sizes=[3, 3, 3]), - dict(channels=[32, 16], kernel_sizes=[3, 3, 3]), - ], - bottom_widths=[32, 32], - skip_mode=skip_mode, - 
activation="relu", - ) - unet = UNet(hparams) - - x = torch.randn(1, *hparams.input_shape) - y = unet(x) - assert y.shape == (1, *hparams.output_shape) - - -@pytest.mark.parametrize("skip_mode", ["concat", "none"]) -@pytest.mark.parametrize("width", [32, 48, 64]) -@pytest.mark.parametrize("height", [32, 48, 64]) -@pytest.mark.parametrize("channels", [1, 3, 5]) -def test_inconsistent_channels(skip_mode, channels, width, height): - hparams = UNetHParams( - input_shape=(channels, height, width), - output_shape=(channels, height, width), - down_blocks=[ - dict(channels=[16, 24], kernel_sizes=[3, 3, 3]), - dict(channels=[48, 17], kernel_sizes=[3, 3, 3]), - ], - up_blocks=[ - dict(channels=[57, 31], kernel_sizes=[3, 3, 3]), - dict(channels=[12, 87], kernel_sizes=[3, 3, 3]), - ], - bottom_widths=[32, 32], - skip_mode=skip_mode, - activation="relu", - ) - unet = UNet(hparams) - - x = torch.randn(1, *hparams.input_shape) - y = unet(x) - assert y.shape == (1, *hparams.output_shape) From cfc1267fb2ba04010368ece67c79b7f51e8b6b8c Mon Sep 17 00:00:00 2001 From: Lars Kuehmichel Date: Mon, 16 Oct 2023 17:03:41 +0200 Subject: [PATCH 33/53] Fix tests to use CPU instead of GPU for GitHub Workflow --- tests/module_tests/test_trainable.py | 4 +++- tests/test_launcher.py | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/module_tests/test_trainable.py b/tests/module_tests/test_trainable.py index 718e5bb..4f71513 100644 --- a/tests/module_tests/test_trainable.py +++ b/tests/module_tests/test_trainable.py @@ -23,9 +23,11 @@ def dummy_hparams_cls(): class DummyHParams(TrainableHParams): max_epochs: int = 10 batch_size: int = 4 + accelerator: str = "cpu" return DummyHParams + @pytest.fixture def dummy_hparams(dummy_hparams_cls): return dummy_hparams_cls() @@ -62,7 +64,7 @@ def test_fit(dummy_model): def test_fit_fast(dummy_model): - loss = dummy_model.fit_fast() + loss = dummy_model.fit_fast(device="cpu") assert torch.isclose(loss, torch.tensor(0.0)) diff --git a/tests/test_launcher.py b/tests/test_launcher.py index 605b448..bd21f22 100644 --- a/tests/test_launcher.py +++ b/tests/test_launcher.py @@ -9,6 +9,8 @@ from lightning_trainable.launcher.grid import GridLauncher, status_count_counter from lightning_trainable.launcher.utils import parse_config_dict +# TODO: tempdir + class BasicTrainableHParams(TrainableHParams): domain: list From f22131830cb3861fe34cb3279605655809132db4 Mon Sep 17 00:00:00 2001 From: Lars Kuehmichel Date: Mon, 16 Oct 2023 17:03:53 +0200 Subject: [PATCH 34/53] Remove TODO --- tests/test_launcher.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_launcher.py b/tests/test_launcher.py index bd21f22..605b448 100644 --- a/tests/test_launcher.py +++ b/tests/test_launcher.py @@ -9,8 +9,6 @@ from lightning_trainable.launcher.grid import GridLauncher, status_count_counter from lightning_trainable.launcher.utils import parse_config_dict -# TODO: tempdir - class BasicTrainableHParams(TrainableHParams): domain: list From 0b9ab7f1e5a20dbc856bfc8c1949e4bf645cb4ff Mon Sep 17 00:00:00 2001 From: Lars Kuehmichel Date: Tue, 17 Oct 2023 11:44:22 +0200 Subject: [PATCH 35/53] Improve checkpointing errors, remove debug exception --- src/lightning_trainable/utils/io.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/lightning_trainable/utils/io.py b/src/lightning_trainable/utils/io.py index 73a0931..ffd8eed 100644 --- a/src/lightning_trainable/utils/io.py +++ b/src/lightning_trainable/utils/io.py @@ -9,6 +9,10 @@ def find_version(root: 
str | Path = "lightning_logs", version: int = "last") -> # Determine latest version number if "last" is passed as version number if version == "last": version_folders = [f for f in root.iterdir() if f.is_dir() and re.match(r"^version_(\d+)$", f.name)] + + if not version_folders: + raise FileNotFoundError(f"Found no folders in '{root}' matching the name pattern 'version_'") + version_numbers = [int(re.match(r"^version_(\d+)$", f.name).group(1)) for f in version_folders] version = max(version_numbers) @@ -25,7 +29,6 @@ def find_epoch_step(root: str | Path, epoch: int = "last", step: int = "last") - @param step: Step number or "last" @return: epoch and step numbers """ - root = Path(root) # get checkpoint filenames @@ -36,6 +39,10 @@ def find_epoch_step(root: str | Path, epoch: int = "last", step: int = "last") - # remove invalid files checkpoints = [cp for cp in checkpoints if pattern.match(cp)] + if not checkpoints: + raise FileNotFoundError(f"Found no checkpoints in '{root}' matching " + f"the name pattern 'epoch=-step=.ckpt'") + # get epochs and steps as list matches = [pattern.match(cp) for cp in checkpoints] epochs, steps = zip(*[(int(match.group(1)), int(match.group(2))) for match in matches]) @@ -43,7 +50,8 @@ def find_epoch_step(root: str | Path, epoch: int = "last", step: int = "last") - if epoch == "last": epoch = max(epochs) elif epoch not in epochs: - raise FileNotFoundError(f"No checkpoint in '{root}' for epoch '{epoch}'.") + closest_epoch = min(epochs, key=lambda e: abs(e - epoch)) + raise FileNotFoundError(f"No checkpoint in '{root}' for epoch '{epoch}'. Closest is '{closest_epoch}'.") # keep only steps for this epoch steps = [s for i, s in enumerate(steps) if epochs[i] == epoch] @@ -51,7 +59,9 @@ def find_epoch_step(root: str | Path, epoch: int = "last", step: int = "last") - if step == "last": step = max(steps) elif step not in steps: - raise FileNotFoundError(f"No checkpoint in '{root}' for epoch '{epoch}', step '{step}'") + closest_step = min(steps, key=lambda s: abs(s - step)) + raise FileNotFoundError(f"No checkpoint in '{root}' for epoch '{epoch}', step '{step}'. " + f"Closest step for this epoch is '{closest_step}'.") return epoch, step @@ -87,7 +97,4 @@ def find_checkpoint(root: str | Path = "lightning_logs", version: int = "last", checkpoint = checkpoint_folder / f"epoch={epoch}-step={step}.ckpt" - if not checkpoint.is_file(): - raise FileNotFoundError(f"{version=}, {epoch=}, {step=}") - return str(checkpoint) From 9ec3b51408d6d668639d1d91d9db4cfd66f8d42b Mon Sep 17 00:00:00 2001 From: Lars Kuehmichel Date: Tue, 17 Oct 2023 11:45:55 +0200 Subject: [PATCH 36/53] Improve docs --- src/lightning_trainable/utils/io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning_trainable/utils/io.py b/src/lightning_trainable/utils/io.py index ffd8eed..7372a71 100644 --- a/src/lightning_trainable/utils/io.py +++ b/src/lightning_trainable/utils/io.py @@ -70,7 +70,7 @@ def find_checkpoint(root: str | Path = "lightning_logs", version: int = "last", """ Helper method to find a lightning checkpoint based on version, epoch and step numbers. - @param root: logs root directory. Usually "lightning_logs/" + @param root: logs root directory. 
Usually "lightning_logs/" or "lightning_logs/" @param version: version number or "last" @param epoch: epoch number or "last" @param step: step number or "last" From 3d27fbadc50a272e9fde1240d356047cd3cf2354 Mon Sep 17 00:00:00 2001 From: Lars Kuehmichel Date: Tue, 17 Oct 2023 11:54:27 +0200 Subject: [PATCH 37/53] Improve errors --- src/lightning_trainable/utils/io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning_trainable/utils/io.py b/src/lightning_trainable/utils/io.py index 7372a71..b44a4c2 100644 --- a/src/lightning_trainable/utils/io.py +++ b/src/lightning_trainable/utils/io.py @@ -79,7 +79,7 @@ def find_checkpoint(root: str | Path = "lightning_logs", version: int = "last", root = Path(root) if not root.is_dir(): - raise ValueError(f"Root directory '{root}' does not exist") + raise ValueError(f"Checkpoint root directory '{root}' does not exist") # get existing version number or error version = find_version(root, version) From dd3df1977110c4c4c9e94efbe0f16b2be687442a Mon Sep 17 00:00:00 2001 From: Lars Kuehmichel Date: Tue, 17 Oct 2023 11:55:07 +0200 Subject: [PATCH 38/53] Rename for clarity `load_checkpoint` -> `find_and_load_from_checkpoint` --- src/lightning_trainable/trainable/trainable.py | 4 ++-- tests/module_tests/test_trainable.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/lightning_trainable/trainable/trainable.py b/src/lightning_trainable/trainable/trainable.py index 16c62e3..181c5fc 100644 --- a/src/lightning_trainable/trainable/trainable.py +++ b/src/lightning_trainable/trainable/trainable.py @@ -320,8 +320,8 @@ def fit_fast(self, device="cuda"): return loss @classmethod - def load_checkpoint(cls, root: str | pathlib.Path = "lightning_logs", version: int | str = "last", - epoch: int | str = "last", step: int | str = "last", **kwargs): + def find_and_load_from_checkpoint(cls, root: str | pathlib.Path = "lightning_logs", version: int | str = "last", + epoch: int | str = "last", step: int | str = "last", **kwargs): checkpoint = utils.find_checkpoint(root, version, epoch, step) return cls.load_from_checkpoint(checkpoint, **kwargs) diff --git a/tests/module_tests/test_trainable.py b/tests/module_tests/test_trainable.py index 4f71513..a6f48c6 100644 --- a/tests/module_tests/test_trainable.py +++ b/tests/module_tests/test_trainable.py @@ -85,7 +85,7 @@ def test_hparams_invariant(dummy_model_cls, dummy_hparams): def test_checkpoint(dummy_model): dummy_model.fit() - trained_model = dummy_model.load_checkpoint() + trained_model = dummy_model.find_and_load_from_checkpoint() def test_nested_checkpoint(dummy_model_cls, dummy_hparams_cls): @@ -104,4 +104,4 @@ def __init__(self, hparams): model.fit() - model.load_checkpoint() + model.find_and_load_from_checkpoint() From e7f19c1e63dcf4a12868ece761d2c3f668d3cead Mon Sep 17 00:00:00 2001 From: Lars Kuehmichel Date: Tue, 17 Oct 2023 12:12:23 +0200 Subject: [PATCH 39/53] Debugging GitHub Workflow --- src/lightning_trainable/utils/io.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/lightning_trainable/utils/io.py b/src/lightning_trainable/utils/io.py index b44a4c2..aff7fae 100644 --- a/src/lightning_trainable/utils/io.py +++ b/src/lightning_trainable/utils/io.py @@ -91,6 +91,15 @@ def find_checkpoint(root: str | Path = "lightning_logs", version: int = "last", checkpoint = checkpoint_folder / "last.ckpt" if checkpoint.is_file(): return str(checkpoint) + else: + contents = "\n".join([str(p) for p in checkpoint_folder.iterdir()]) + msg = f""" + Could not find 
'last.ckpt' in '{checkpoint_folder}'. + Checkpoint folder exists? {checkpoint_folder.is_dir()} + Checkpoint folder contents: + {contents} + """ + raise RuntimeError(msg) # get existing epoch and step number or error epoch, step = find_epoch_step(checkpoint_folder, epoch, step) From 3849e08c677a6d76dc0301a32f41094864367a2d Mon Sep 17 00:00:00 2001 From: Lars Kuehmichel Date: Tue, 17 Oct 2023 13:23:58 +0200 Subject: [PATCH 40/53] Debugging GitHub Workflow --- src/lightning_trainable/utils/io.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/lightning_trainable/utils/io.py b/src/lightning_trainable/utils/io.py index aff7fae..749a9e4 100644 --- a/src/lightning_trainable/utils/io.py +++ b/src/lightning_trainable/utils/io.py @@ -86,13 +86,15 @@ def find_checkpoint(root: str | Path = "lightning_logs", version: int = "last", checkpoint_folder = root / f"version_{version}" / "checkpoints" + contents = "\n".join([str(p) for p in checkpoint_folder.iterdir()]) + if epoch == "last" and step == "last": # return last.ckpt if it exists checkpoint = checkpoint_folder / "last.ckpt" if checkpoint.is_file(): return str(checkpoint) else: - contents = "\n".join([str(p) for p in checkpoint_folder.iterdir()]) + # TODO: remove debug error msg = f""" Could not find 'last.ckpt' in '{checkpoint_folder}'. Checkpoint folder exists? {checkpoint_folder.is_dir()} From 205b688d8bef680cf47bebd5624652d4724b8db8 Mon Sep 17 00:00:00 2001 From: Lars Kuehmichel Date: Tue, 17 Oct 2023 13:27:32 +0200 Subject: [PATCH 41/53] Debugging GitHub Workflow --- src/lightning_trainable/utils/io.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/lightning_trainable/utils/io.py b/src/lightning_trainable/utils/io.py index 749a9e4..643c997 100644 --- a/src/lightning_trainable/utils/io.py +++ b/src/lightning_trainable/utils/io.py @@ -98,6 +98,9 @@ def find_checkpoint(root: str | Path = "lightning_logs", version: int = "last", msg = f""" Could not find 'last.ckpt' in '{checkpoint_folder}'. Checkpoint folder exists? {checkpoint_folder.is_dir()} + Checkpoint exists? {checkpoint.exists()} + Checkpoint is file? 
{checkpoint.is_file()} + Checkpoint path: {checkpoint} Checkpoint folder contents: {contents} """ From 15c7f6c3ab694061031ad21fc19a6c15224c8f39 Mon Sep 17 00:00:00 2001 From: Lars Kuehmichel Date: Tue, 17 Oct 2023 13:44:26 +0200 Subject: [PATCH 42/53] Debugging GitHub Workflow Trying os instead of pathlib --- src/lightning_trainable/utils/io.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/lightning_trainable/utils/io.py b/src/lightning_trainable/utils/io.py index 643c997..24ae436 100644 --- a/src/lightning_trainable/utils/io.py +++ b/src/lightning_trainable/utils/io.py @@ -88,10 +88,12 @@ def find_checkpoint(root: str | Path = "lightning_logs", version: int = "last", contents = "\n".join([str(p) for p in checkpoint_folder.iterdir()]) + import os + if epoch == "last" and step == "last": # return last.ckpt if it exists checkpoint = checkpoint_folder / "last.ckpt" - if checkpoint.is_file(): + if os.path.isfile(str(checkpoint)): return str(checkpoint) else: # TODO: remove debug error From ce4bf1018332ef19069f2c106cf83337d1a8f2b0 Mon Sep 17 00:00:00 2001 From: Lars Kuehmichel Date: Tue, 17 Oct 2023 13:52:32 +0200 Subject: [PATCH 43/53] Debugging GitHub Workflow Resolving checkpoint path for more info --- src/lightning_trainable/utils/io.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/lightning_trainable/utils/io.py b/src/lightning_trainable/utils/io.py index 24ae436..4155121 100644 --- a/src/lightning_trainable/utils/io.py +++ b/src/lightning_trainable/utils/io.py @@ -88,12 +88,11 @@ def find_checkpoint(root: str | Path = "lightning_logs", version: int = "last", contents = "\n".join([str(p) for p in checkpoint_folder.iterdir()]) - import os - if epoch == "last" and step == "last": # return last.ckpt if it exists checkpoint = checkpoint_folder / "last.ckpt" - if os.path.isfile(str(checkpoint)): + checkpoint = checkpoint.resolve() + if checkpoint.is_file(): return str(checkpoint) else: # TODO: remove debug error From 6a9ac9723d0e313eeef1f903f2910cbf3f679cb9 Mon Sep 17 00:00:00 2001 From: Lars Kuehmichel Date: Tue, 17 Oct 2023 14:10:10 +0200 Subject: [PATCH 44/53] Turns out, GitHub is just broken Seems like a symlink loop on GitHub's end --- src/lightning_trainable/utils/io.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/src/lightning_trainable/utils/io.py b/src/lightning_trainable/utils/io.py index 4155121..b44a4c2 100644 --- a/src/lightning_trainable/utils/io.py +++ b/src/lightning_trainable/utils/io.py @@ -86,26 +86,11 @@ def find_checkpoint(root: str | Path = "lightning_logs", version: int = "last", checkpoint_folder = root / f"version_{version}" / "checkpoints" - contents = "\n".join([str(p) for p in checkpoint_folder.iterdir()]) - if epoch == "last" and step == "last": # return last.ckpt if it exists checkpoint = checkpoint_folder / "last.ckpt" - checkpoint = checkpoint.resolve() if checkpoint.is_file(): return str(checkpoint) - else: - # TODO: remove debug error - msg = f""" - Could not find 'last.ckpt' in '{checkpoint_folder}'. - Checkpoint folder exists? {checkpoint_folder.is_dir()} - Checkpoint exists? {checkpoint.exists()} - Checkpoint is file? 
{checkpoint.is_file()} - Checkpoint path: {checkpoint} - Checkpoint folder contents: - {contents} - """ - raise RuntimeError(msg) # get existing epoch and step number or error epoch, step = find_epoch_step(checkpoint_folder, epoch, step) From 2d467a755575ede3c0a500e13a000650bd41c048 Mon Sep 17 00:00:00 2001 From: Lars Kuehmichel Date: Tue, 17 Oct 2023 16:01:54 +0200 Subject: [PATCH 45/53] Improved Tests, add WIP fix for #10 --- .../trainable/trainable.py | 74 ++++++++----------- tests/module_tests/test_trainable.py | 31 +++++++- 2 files changed, 60 insertions(+), 45 deletions(-) diff --git a/src/lightning_trainable/trainable/trainable.py b/src/lightning_trainable/trainable/trainable.py index 181c5fc..079612f 100644 --- a/src/lightning_trainable/trainable/trainable.py +++ b/src/lightning_trainable/trainable/trainable.py @@ -62,6 +62,10 @@ def compute_metrics(self, batch, batch_idx) -> dict: """ raise NotImplementedError + def on_train_start(self): + # TODO: get metrics to track from compute_metrics + self.logger.log_hyperparams(self.hparams, ...) + def training_step(self, batch, batch_idx): try: metrics = self.compute_metrics(batch, batch_idx) @@ -186,28 +190,22 @@ def test_dataloader(self) -> DataLoader | list[DataLoader]: num_workers=self.hparams.num_workers, ) - def configure_logger(self, save_dir=os.getcwd(), **kwargs) -> Logger: + def configure_logger(self, logger_name: str = "TensorBoardLogger", **logger_kwargs) -> Logger: """ Instantiate the Logger used by the Trainer in module fitting. By default, we use a TensorBoardLogger, but you can use any other logger of your choice. - @param save_dir: The root directory in which all your experiments with - different names and versions will be stored. - @param kwargs: Keyword-Arguments to the Logger. Set `logger_name` to use a different logger than TensorBoardLogger. + @param logger_name: The name of the logger to use. Defaults to TensorBoardLogger. + @param logger_kwargs: Keyword-Arguments to the Logger. Set `logger_name` to use a different logger than TensorBoardLogger. @return: The Logger object. """ - logger_kwargs = deepcopy(kwargs) - logger_kwargs.update(dict( - save_dir=save_dir, - )) - logger_name = logger_kwargs.pop("logger_name", "TensorBoardLogger") + logger_kwargs = logger_kwargs or {} + logger_kwargs.setdefault("save_dir", os.getcwd()) logger_class = utils.get_logger(logger_name) if issubclass(logger_class, TensorBoardLogger): - logger_kwargs["default_hp_metric"] = False + logger_kwargs.setdefault("default_hp_metric", False) - return logger_class( - **logger_kwargs - ) + return logger_class(**logger_kwargs) def configure_trainer(self, logger_kwargs: dict = None, trainer_kwargs: dict = None) -> lightning.Trainer: """ @@ -219,23 +217,20 @@ def configure_trainer(self, logger_kwargs: dict = None, trainer_kwargs: dict = N See also :func:`~trainable.Trainable.configure_trainer`. @return: The Lightning Trainer object. 
""" - if logger_kwargs is None: - logger_kwargs = dict() - if trainer_kwargs is None: - trainer_kwargs = dict() - - return lightning.Trainer( - accelerator=self.hparams.accelerator.lower(), - logger=self.configure_logger(**logger_kwargs), - devices=self.hparams.devices, - max_epochs=self.hparams.max_epochs, - max_steps=self.hparams.max_steps, - gradient_clip_val=self.hparams.gradient_clip, - accumulate_grad_batches=self.hparams.accumulate_batches, - profiler=self.hparams.profiler, - benchmark=True, - **trainer_kwargs, - ) + logger_kwargs = logger_kwargs or {} + trainer_kwargs = trainer_kwargs or {} + + trainer_kwargs.setdefault("accelerator", self.hparams.accelerator.lower()) + trainer_kwargs.setdefault("accumulate_grad_batches", self.hparams.accumulate_batches) + trainer_kwargs.setdefault("benchmark", True) + trainer_kwargs.setdefault("devices", self.hparams.devices) + trainer_kwargs.setdefault("gradient_clip_val", self.hparams.gradient_clip) + trainer_kwargs.setdefault("logger", self.configure_logger(**logger_kwargs)) + trainer_kwargs.setdefault("max_epochs", self.hparams.max_epochs) + trainer_kwargs.setdefault("max_steps", self.hparams.max_steps) + trainer_kwargs.setdefault("profiler", self.hparams.profiler) + + return lightning.Trainer(**trainer_kwargs) def on_before_optimizer_step(self, optimizer): # who doesn't love breaking changes in underlying libraries @@ -260,20 +255,21 @@ def fit(self, logger_kwargs: dict = None, trainer_kwargs: dict = None, fit_kwarg @param fit_kwargs: Keyword-Arguments to the Trainer's fit method. @return: Validation Metrics as defined in :func:`~trainable.Trainable.compute_metrics`. """ - if logger_kwargs is None: - logger_kwargs = dict() - if trainer_kwargs is None: - trainer_kwargs = dict() - if fit_kwargs is None: - fit_kwargs = dict() + logger_kwargs = logger_kwargs or {} + trainer_kwargs = trainer_kwargs or {} + fit_kwargs = fit_kwargs or {} trainer = self.configure_trainer(logger_kwargs, trainer_kwargs) + + # TODO: move this to on_train_start(), see https://lightning.ai/docs/pytorch/stable/extensions/logging.html#logging-hyperparameters metrics_list = trainer.validate(self) if metrics_list is not None and len(metrics_list) > 0: metrics = metrics_list[0] else: metrics = {} trainer.logger.log_hyperparams(self.hparams, metrics) + + trainer.fit(self, **fit_kwargs) return { @@ -319,12 +315,6 @@ def fit_fast(self, device="cuda"): return loss - @classmethod - def find_and_load_from_checkpoint(cls, root: str | pathlib.Path = "lightning_logs", version: int | str = "last", - epoch: int | str = "last", step: int | str = "last", **kwargs): - checkpoint = utils.find_checkpoint(root, version, epoch, step) - return cls.load_from_checkpoint(checkpoint, **kwargs) - def auto_pin_memory(pin_memory: bool | None, accelerator: str): if pin_memory is None: diff --git a/tests/module_tests/test_trainable.py b/tests/module_tests/test_trainable.py index a6f48c6..c424ab8 100644 --- a/tests/module_tests/test_trainable.py +++ b/tests/module_tests/test_trainable.py @@ -5,7 +5,10 @@ import torch.nn as nn from torch.utils.data import TensorDataset +from pathlib import Path + from lightning_trainable.trainable import Trainable, TrainableHParams +from lightning_trainable.utils import find_checkpoint @pytest.fixture @@ -83,9 +86,13 @@ def test_hparams_invariant(dummy_model_cls, dummy_hparams): def test_checkpoint(dummy_model): + # TODO: temp directory + dummy_model.fit() - trained_model = dummy_model.find_and_load_from_checkpoint() + checkpoint = find_checkpoint() + + assert 
Path(checkpoint).is_file()


 def test_nested_checkpoint(dummy_model_cls, dummy_hparams_cls):
@@ -94,6 +101,8 @@ class MyHParams(dummy_hparams_cls):
         pass

     class MyModel(dummy_model_cls):
+        hparams: MyHParams
+
         def __init__(self, hparams):
             super().__init__(hparams)

@@ -102,6 +111,22 @@ def __init__(self, hparams):
     assert model._hparams_name == "hparams"

-    model.fit()

-    model.find_and_load_from_checkpoint()
+def test_continue_training(dummy_model):
+    print("Starting Training.")
+    dummy_model.fit()
+
+    print("Finished Training. Loading Checkpoint.")
+    checkpoint = find_checkpoint()
+
+    trained_model = dummy_model.load_from_checkpoint(checkpoint)
+
+    print("Continuing Training.")
+    trained_model.fit(
+        trainer_kwargs=dict(max_epochs=2 * dummy_model.hparams.max_epochs),
+        fit_kwargs=dict(ckpt_path=checkpoint)
+    )
+
+    print("Finished Continued Training.")
+
+    # TODO: add check that the model was actually trained for 2x epochs

From eb9a4c7f1c36dc674d2a1b6a05d1eedcb1a7bd3b Mon Sep 17 00:00:00 2001
From: Lars Kuehmichel
Date: Tue, 17 Oct 2023 16:02:15 +0200
Subject: [PATCH 46/53] Add macos and python 3.11 to workflow

---
 .github/workflows/tests.yaml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index 36f78e4..14fcbe2 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -9,8 +9,8 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        os: [ubuntu-latest, windows-latest]
-        python-version: ["3.10"]
+        os: [ubuntu-latest, windows-latest, macos-latest]
+        python-version: ["3.10", "3.11"]

     steps:
       - uses: actions/checkout@v3

From e63137a7acfc94435d7fcfeb839e23420bcb6ec4 Mon Sep 17 00:00:00 2001
From: Lars Kuehmichel
Date: Fri, 27 Oct 2023 14:46:43 +0200
Subject: [PATCH 47/53] Added test for and fixed #13

---
 src/lightning_trainable/hparams/attribute_dict.py | 5 +++++
 tests/module_tests/test_trainable.py | 13 +++++++++++++
 2 files changed, 18 insertions(+)

diff --git a/src/lightning_trainable/hparams/attribute_dict.py b/src/lightning_trainable/hparams/attribute_dict.py
index dc484a1..98a4534 100644
--- a/src/lightning_trainable/hparams/attribute_dict.py
+++ b/src/lightning_trainable/hparams/attribute_dict.py
@@ -14,3 +14,8 @@ def __getattribute__(self, item):

     def __setattr__(self, key, value):
         self[key] = value
+
+    def copy(self):
+        # copies of AttributeDicts should be AttributeDicts
+        # see also https://github.com/LarsKue/lightning-trainable/issues/13
+        return self.__class__(**super().copy())
diff --git a/tests/module_tests/test_trainable.py b/tests/module_tests/test_trainable.py
index c424ab8..a7bb80c 100644
--- a/tests/module_tests/test_trainable.py
+++ b/tests/module_tests/test_trainable.py
@@ -7,6 +7,7 @@

 from pathlib import Path

+from lightning_trainable.hparams import HParams
 from lightning_trainable.trainable import Trainable, TrainableHParams
 from lightning_trainable.utils import find_checkpoint

@@ -72,6 +73,18 @@ def test_fit_fast(dummy_model):
     assert torch.isclose(loss, torch.tensor(0.0))


+def test_hparams_copy(dummy_hparams, dummy_hparams_cls):
+    assert isinstance(dummy_hparams, HParams)
+    assert isinstance(dummy_hparams, TrainableHParams)
+    assert isinstance(dummy_hparams, dummy_hparams_cls)
+
+    hparams_copy = dummy_hparams.copy()
+
+    assert isinstance(hparams_copy, HParams)
+    assert isinstance(dummy_hparams, TrainableHParams)
+    assert isinstance(dummy_hparams, dummy_hparams_cls)
+
+
 def test_hparams_invariant(dummy_model_cls, dummy_hparams):
     """ Ensure HParams are left unchanged after instantiation and training """
     hparams = dummy_hparams.copy()

From 927ee5f872f20dd4c164981d3358fe59eea70702 Mon Sep 17 00:00:00 2001
From: Lars Kuehmichel
Date: Fri, 27 Oct 2023 15:31:47 +0200
Subject: [PATCH 48/53] Fix default value for optimizer

---
 src/lightning_trainable/trainable/trainable_hparams.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning_trainable/trainable/trainable_hparams.py b/src/lightning_trainable/trainable/trainable_hparams.py
index 66c000a..d27a455 100644
--- a/src/lightning_trainable/trainable/trainable_hparams.py
+++ b/src/lightning_trainable/trainable/trainable_hparams.py
@@ -12,7 +12,7 @@ class TrainableHParams(HParams):
     devices: int = 1
     max_epochs: int | None
     max_steps: int = -1
-    optimizer: str | dict | None = "adam"
+    optimizer: str | dict | None = "Adam"
     lr_scheduler: str | dict | None = None
     batch_size: int
     accumulate_batches: int = 1

From a17f51ea2f6fb3c9274951a369212f47a1abff38 Mon Sep 17 00:00:00 2001
From: Lars Kuehmichel
Date: Fri, 27 Oct 2023 15:33:11 +0200
Subject: [PATCH 49/53] Clean up

---
 src/lightning_trainable/trainable/trainable.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/src/lightning_trainable/trainable/trainable.py b/src/lightning_trainable/trainable/trainable.py
index 079612f..bff3e36 100644
--- a/src/lightning_trainable/trainable/trainable.py
+++ b/src/lightning_trainable/trainable/trainable.py
@@ -1,6 +1,5 @@
 import lightning
 import os
-import pathlib
 import torch

 from copy import deepcopy
@@ -62,10 +61,6 @@ def compute_metrics(self, batch, batch_idx) -> dict:
         """
         raise NotImplementedError

-    def on_train_start(self):
-        # TODO: get metrics to track from compute_metrics
-        self.logger.log_hyperparams(self.hparams, ...)
-
     def training_step(self, batch, batch_idx):
         try:
             metrics = self.compute_metrics(batch, batch_idx)

From eb29b1603999c91fa0d4d157a1a139302143901e Mon Sep 17 00:00:00 2001
From: Lars Kuehmichel
Date: Fri, 27 Oct 2023 15:34:43 +0200
Subject: [PATCH 50/53] Move hparam logging to on_train_start with test batch from train_dataloader

---
 src/lightning_trainable/trainable/trainable.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/src/lightning_trainable/trainable/trainable.py b/src/lightning_trainable/trainable/trainable.py
index bff3e36..fc094e4 100644
--- a/src/lightning_trainable/trainable/trainable.py
+++ b/src/lightning_trainable/trainable/trainable.py
@@ -238,6 +238,14 @@ def on_before_optimizer_step(self, optimizer):
             case other:
                 raise NotImplementedError(f"Unrecognized grad norm: {other}")

+    def on_train_start(self) -> None:
+        # get hparams metrics with a test batch
+        test_batch = next(iter(self.trainer.train_dataloader))
+        metrics = self.compute_metrics(test_batch, 0)
+
+        # add hparams to tensorboard
+        self.logger.log_hyperparams(self.hparams, metrics)
+
     @torch.enable_grad()
     def fit(self, logger_kwargs: dict = None, trainer_kwargs: dict = None, fit_kwargs: dict = None) -> dict:
         """
@@ -256,15 +264,6 @@ def fit(self, logger_kwargs: dict = None, trainer_kwargs: dict = None, fit_kwarg
         @param fit_kwargs: Keyword-Arguments to the Trainer's fit method.
         @return: Validation Metrics as defined in :func:`~trainable.Trainable.compute_metrics`.
         """
         logger_kwargs = logger_kwargs or {}
         trainer_kwargs = trainer_kwargs or {}
         fit_kwargs = fit_kwargs or {}

         trainer = self.configure_trainer(logger_kwargs, trainer_kwargs)
-
-        # TODO: move this to on_train_start(), see https://lightning.ai/docs/pytorch/stable/extensions/logging.html#logging-hyperparameters
-        metrics_list = trainer.validate(self)
-        if metrics_list is not None and len(metrics_list) > 0:
-            metrics = metrics_list[0]
-        else:
-            metrics = {}
-        trainer.logger.log_hyperparams(self.hparams, metrics)
-
-
         trainer.fit(self, **fit_kwargs)

         return {

From fa842a3970739155a33608d5957b3cbdfc0c3c0f Mon Sep 17 00:00:00 2001
From: Lars Kuehmichel
Date: Fri, 27 Oct 2023 15:35:00 +0200
Subject: [PATCH 51/53] Fix minor error in fit return

---
 src/lightning_trainable/trainable/trainable.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning_trainable/trainable/trainable.py b/src/lightning_trainable/trainable/trainable.py
index fc094e4..9684043 100644
--- a/src/lightning_trainable/trainable/trainable.py
+++ b/src/lightning_trainable/trainable/trainable.py
@@ -269,7 +269,7 @@ def fit(self, logger_kwargs: dict = None, trainer_kwargs: dict = None, fit_kwarg
         return {
             key: value.item()
             for key, value in trainer.callback_metrics.items()
-            if any(key.startswith(key) for key in ["training/", "validation/"])
+            if any(key.startswith(k) for k in ["training/", "validation/"])
         }

     @torch.enable_grad()

From b87f433ef51b43eafed93ff6c75d0d89d226c9ea Mon Sep 17 00:00:00 2001
From: Lars Kuehmichel
Date: Fri, 27 Oct 2023 15:45:09 +0200
Subject: [PATCH 52/53] Fix keywords

apparently spaces are no longer allowed
---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 32ff79a..0cef9e1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,7 +9,7 @@ description = "A light-weight lightning_trainable module for pytorch-lightning."
 readme = "README.md"
 requires-python = ">=3.10"
 license = { file = "LICENSE" }
-keywords = ["Machine Learning", "PyTorch", "PyTorch-Lightning"]
+keywords = ["Machine-Learning", "PyTorch", "PyTorch-Lightning"]
 authors = [
     { name = "Lars Kühmichel", email = "lars.kuehmichel@stud.uni-heidelberg.de" }

From 935ffddd6c4c4120e10d679074d7903c56789c16 Mon Sep 17 00:00:00 2001
From: Lars Kuehmichel
Date: Fri, 27 Oct 2023 16:24:44 +0200
Subject: [PATCH 53/53] Fix test for lightning 2.1

---
 tests/module_tests/test_trainable.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/module_tests/test_trainable.py b/tests/module_tests/test_trainable.py
index a7bb80c..31283cf 100644
--- a/tests/module_tests/test_trainable.py
+++ b/tests/module_tests/test_trainable.py
@@ -132,7 +132,7 @@ def test_continue_training(dummy_model):
     print("Finished Training. Loading Checkpoint.")
     checkpoint = find_checkpoint()

-    trained_model = dummy_model.load_from_checkpoint(checkpoint)
+    trained_model = dummy_model.__class__.load_from_checkpoint(checkpoint)

     print("Continuing Training.")
     trained_model.fit(