diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 36f78e4..14fcbe2 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -9,8 +9,8 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest, windows-latest] - python-version: ["3.10"] + os: [ubuntu-latest, windows-latest, macos-latest] + python-version: ["3.10", "3.11"] steps: - uses: actions/checkout@v3 diff --git a/pyproject.toml b/pyproject.toml index 2004e9e..0cef9e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ description = "A light-weight lightning_trainable module for pytorch-lightning." readme = "README.md" requires-python = ">=3.10" license = { file = "LICENSE" } -keywords = ["Machine Learning", "PyTorch", "PyTorch-Lightning"] +keywords = ["Machine-Learning", "PyTorch", "PyTorch-Lightning"] authors = [ { name = "Lars Kühmichel", email = "lars.kuehmichel@stud.uni-heidelberg.de" } @@ -57,7 +57,9 @@ tests = [ ] experiments = [ - # these are just recommended packages to run experiments + # required + "ray[tune] ~= 2.4", + # recommended "numpy ~= 1.24", "matplotlib ~= 3.7", "jupyterlab ~= 3.6", diff --git a/src/lightning_trainable/callbacks/epoch_progress_bar.py b/src/lightning_trainable/callbacks/epoch_progress_bar.py index 9dd973f..a375c97 100644 --- a/src/lightning_trainable/callbacks/epoch_progress_bar.py +++ b/src/lightning_trainable/callbacks/epoch_progress_bar.py @@ -2,7 +2,11 @@ from lightning.pytorch.callbacks import ProgressBar from lightning.pytorch.callbacks.progress.tqdm_progress import Tqdm +from lightning_trainable.utils import deprecate + +@deprecate("EpochProgressBar causes issues when continuing training or using multi-GPU. " + "Use the default Lightning ProgressBar instead.") class EpochProgressBar(ProgressBar): def __init__(self): super().__init__() @@ -25,6 +29,8 @@ def on_train_end(self, trainer, pl_module): self.bar.close() +@deprecate("StepProgressBar causes issues when continuing training or using multi-GPU. 
" + "Use the default Lightning ProgressBar instead.") class StepProgressBar(ProgressBar): def __init__(self): super().__init__() diff --git a/src/lightning_trainable/datasets/core/__init__.py b/src/lightning_trainable/datasets/core/__init__.py index a0f49bf..6e526bd 100644 --- a/src/lightning_trainable/datasets/core/__init__.py +++ b/src/lightning_trainable/datasets/core/__init__.py @@ -1 +1,2 @@ +from .distribution_dataset import DistributionDataset from .joint import JointDataset, JointIterableDataset diff --git a/src/lightning_trainable/hparams/attribute_dict.py b/src/lightning_trainable/hparams/attribute_dict.py index dc484a1..98a4534 100644 --- a/src/lightning_trainable/hparams/attribute_dict.py +++ b/src/lightning_trainable/hparams/attribute_dict.py @@ -14,3 +14,8 @@ def __getattribute__(self, item): def __setattr__(self, key, value): self[key] = value + + def copy(self): + # copies of AttributeDicts should be AttributeDicts + # see also https://github.com/LarsKue/lightning-trainable/issues/13 + return self.__class__(**super().copy()) diff --git a/src/lightning_trainable/hparams/types/choice.py b/src/lightning_trainable/hparams/types/choice.py index 4c0c545..00315ca 100644 --- a/src/lightning_trainable/hparams/types/choice.py +++ b/src/lightning_trainable/hparams/types/choice.py @@ -11,6 +11,9 @@ def __call__(cls, *choices): namespace = {"choices": choices} return type(name, bases, namespace) + def __repr__(cls): + return f"Choice{cls.choices!r}" + class Choice(metaclass=ChoiceMeta): """ diff --git a/src/lightning_trainable/hparams/types/range.py b/src/lightning_trainable/hparams/types/range.py index b982b93..a20b739 100644 --- a/src/lightning_trainable/hparams/types/range.py +++ b/src/lightning_trainable/hparams/types/range.py @@ -21,6 +21,9 @@ def __call__(cls, lower: float | int, upper: float | int, exclude: str | None = namespace = {"lower": lower, "upper": upper, "exclude": exclude} return type(name, bases, namespace) + def __repr__(self): + return f"Range({self.lower!r}, {self.upper!r}, exclude={self.exclude!r})" + class Range(metaclass=RangeMeta): """ diff --git a/src/lightning_trainable/metrics/__init__.py b/src/lightning_trainable/metrics/__init__.py index 06a1dab..30bcacd 100644 --- a/src/lightning_trainable/metrics/__init__.py +++ b/src/lightning_trainable/metrics/__init__.py @@ -1 +1,4 @@ from .accuracy import accuracy +from .error import error +from .sinkhorn import sinkhorn_auto as sinkhorn +from .wasserstein import wasserstein diff --git a/src/lightning_trainable/metrics/error.py b/src/lightning_trainable/metrics/error.py new file mode 100644 index 0000000..ec4b458 --- /dev/null +++ b/src/lightning_trainable/metrics/error.py @@ -0,0 +1,6 @@ + +from .accuracy import accuracy + + +def error(logits, targets, *, k=1): + return 1.0 - accuracy(logits, targets, k=k) diff --git a/src/lightning_trainable/metrics/sinkhorn.py b/src/lightning_trainable/metrics/sinkhorn.py new file mode 100644 index 0000000..623ecdc --- /dev/null +++ b/src/lightning_trainable/metrics/sinkhorn.py @@ -0,0 +1,107 @@ +import warnings + +import torch +from torch import Tensor + +import torch.nn.functional as F +import numpy as np + + +def sinkhorn(a: Tensor, b: Tensor, cost: Tensor, epsilon: float, steps: int = 1000) -> Tensor: + """ + Computes the Sinkhorn optimal transport plan from sample weights of two distributions. + This version does not use log-space computations, but allows for zero or negative weights. 
+ + @param a: Sample weights from the first distribution in shape (n,) + @param b: Sample weights from the second distribution in shape (m,) + @param cost: Cost matrix in shape (n, m). + @param epsilon: Entropic regularization parameter. + @param steps: Number of iterations. + """ + if cost.shape != (len(a), len(b)): + raise ValueError(f"Expected cost to have shape {(len(a), len(b))}, but got {cost.shape}.") + + gain = torch.exp(-cost / epsilon) + + if gain.mean() < 1e-30: + warnings.warn(f"Detected low bandwidth ({epsilon:.1e}) relative to cost ({cost.mean().item():.1e}). " + f"You may experience numerical instabilities. Consider increasing epsilon or using sinkhorn_log.") + + # Initialize the dual variables. + u = torch.ones(len(a), dtype=a.dtype, device=a.device) + v = torch.ones(len(b), dtype=b.dtype, device=b.device) + + # Compute the Sinkhorn iterations. + for _ in range(steps): + v = b / (torch.matmul(gain.T, u) + 1e-50) + u = a / (torch.matmul(gain, v) + 1e-50) + + # Return the transport plan. + return u[:, None] * gain * v[None, :] + + +def sinkhorn_log(log_a: Tensor, log_b: Tensor, cost: Tensor, epsilon: float, steps: int = 1000) -> Tensor: + """ + Computes the Sinkhorn optimal transport plan from sample weights of two distributions. + This version uses log-space computations to avoid numerical instabilities, but disallows zero or negative weights. + + @param log_a: Log sample weights from the first distribution in shape (n,) + @param log_b: Log sample weights from the second distribution in shape (m,) + @param cost: Cost matrix in shape (n, m). + @param epsilon: Entropic regularization parameter. + @param steps: Number of iterations. + """ + if cost.shape != (len(log_a), len(log_b)): + raise ValueError(f"Expected cost to have shape {(len(log_a), len(log_b))}, but got {cost.shape}.") + + log_gain = -cost / epsilon + + # Initialize the dual variables. + log_u = torch.zeros(len(log_a), dtype=log_a.dtype, device=log_a.device) + log_v = torch.zeros(len(log_b), dtype=log_b.dtype, device=log_b.device) + + # Compute the Sinkhorn iterations. + for _ in range(steps): + log_v = log_b - torch.logsumexp(log_gain + log_u[:, None], dim=0) + log_u = log_a - torch.logsumexp(log_gain + log_v[None, :], dim=1) + + plan = torch.exp(log_u[:, None] + log_gain + log_v[None, :]) + + if not torch.allclose(len(log_b) * plan.sum(dim=0), torch.ones(len(log_b), device=plan.device)) or not torch.allclose(len(log_a) * plan.sum(dim=1), torch.ones(len(log_a), device=plan.device)): + warnings.warn(f"Sinkhorn did not converge. Consider increasing epsilon or number of iterations.") + + # Return the transport plan. + return plan + + +def sinkhorn_auto(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: float = 1.0, steps: int = 1000) -> Tensor: + """ + Computes the Sinkhorn optimal transport plan from samples from two distributions. + See also: sinkhorn_log + + @param x: Samples from the first distribution in shape (n, ...). + @param y: Samples from the second distribution in shape (m, ...). + @param cost: Optional cost matrix in shape (n, m). + If not provided, the Euclidean distance is used. + @param epsilon: Optional entropic regularization parameter. + This parameter is normalized to the half-mean of the cost matrix. This helps keep the value independent + of the data dimensionality. Note that this behaviour is exclusive to this method; sinkhorn_log only accepts + the raw entropic regularization value. + @param steps: Number of iterations. 
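For reference, a minimal sketch (sample sizes, dimensionality and epsilon are assumed values) of calling the two low-level solvers defined above directly with explicit weights: sinkhorn takes raw sample weights, sinkhorn_log takes log-weights.

import torch
from lightning_trainable.metrics.sinkhorn import sinkhorn, sinkhorn_log

n, m = 64, 128
x, y = torch.randn(n, 2), torch.randn(m, 2)
cost = torch.cdist(x, y)              # (n, m) Euclidean cost matrix
a = torch.full((n,), 1.0 / n)         # uniform sample weights
b = torch.full((m,), 1.0 / m)

plan = sinkhorn(a, b, cost, epsilon=0.1)                      # transport plan, shape (n, m)
log_plan = sinkhorn_log(a.log(), b.log(), cost, epsilon=0.1)  # log-space variant, same shape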
+ """ + if x.shape[1:] != y.shape[1:]: + raise ValueError(f"Expected x and y to live in the same feature space, " + f"but got {x.shape[1:]} and {y.shape[1:]}.") + if cost is None: + cost = x[:, None] - y[None, :] + cost = torch.flatten(cost, start_dim=2) + cost = torch.linalg.norm(cost, dim=-1) + + # Initialize epsilon independent of the data dimension (i.e. dependent on the mean cost) + epsilon = epsilon * cost.mean() / 2 + + # Initialize the sample weights. + log_a = torch.zeros(len(x), device=x.device) - np.log(len(x)) + log_b = torch.zeros(len(y), device=y.device) - np.log(len(y)) + + return sinkhorn_log(log_a, log_b, cost, epsilon, steps) diff --git a/src/lightning_trainable/metrics/wasserstein.py b/src/lightning_trainable/metrics/wasserstein.py new file mode 100644 index 0000000..1fdb447 --- /dev/null +++ b/src/lightning_trainable/metrics/wasserstein.py @@ -0,0 +1,18 @@ + +import torch +from torch import Tensor + +from .sinkhorn import sinkhorn_auto + + +def wasserstein(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: float = 0.1, steps: int = 10) -> Tensor: + """ + Computes the Wasserstein distance between two distributions. + See also: sinkhorn_auto + """ + if cost is None: + cost = x[:, None] - y[None, :] + cost = torch.flatten(cost, start_dim=2) + cost = torch.linalg.norm(cost, dim=-1) + + return torch.sum(sinkhorn_auto(x, y, cost, epsilon, steps) * cost) diff --git a/src/lightning_trainable/modules/__init__.py b/src/lightning_trainable/modules/__init__.py index b5cb254..a710cc6 100644 --- a/src/lightning_trainable/modules/__init__.py +++ b/src/lightning_trainable/modules/__init__.py @@ -1,4 +1,3 @@ -from .convolutional import ConvolutionalNetwork, ConvolutionalNetworkHParams from .fully_connected import FullyConnectedNetwork, FullyConnectedNetworkHParams from .hparams_module import HParamsModule -from .unet import UNet, UNetHParams, UNetBlockHParams +from .simple_unet import SimpleUNet, SimpleUNetHParams diff --git a/src/lightning_trainable/modules/convolutional/__init__.py b/src/lightning_trainable/modules/convolutional/__init__.py deleted file mode 100644 index 7111d1e..0000000 --- a/src/lightning_trainable/modules/convolutional/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .block import ConvolutionalBlock -from .block_hparams import ConvolutionalBlockHParams - -from .network import ConvolutionalNetwork -from .hparams import ConvolutionalNetworkHParams diff --git a/src/lightning_trainable/modules/convolutional/block.py b/src/lightning_trainable/modules/convolutional/block.py deleted file mode 100644 index e83f743..0000000 --- a/src/lightning_trainable/modules/convolutional/block.py +++ /dev/null @@ -1,74 +0,0 @@ -import torch -import torch.nn as nn - -import lightning_trainable.utils as utils - -from ..hparams_module import HParamsModule - -from .block_hparams import ConvolutionalBlockHParams - - -class ConvolutionalBlock(HParamsModule): - """ - Implements a series of convolutions, each followed by an activation function. 
- """ - - hparams: ConvolutionalBlockHParams - - def __init__(self, hparams: dict | ConvolutionalBlockHParams): - super().__init__(hparams) - - self.network = self.configure_network() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.network(x) - - def configure_network(self): - # construct convolutions - convolutions = [] - cck = zip(self.hparams.channels[:-1], self.hparams.channels[1:], self.hparams.kernel_sizes) - for in_channels, out_channels, kernel_size in cck: - conv = nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - padding=self.hparams.padding, - dilation=self.hparams.dilation, - groups=self.hparams.groups, - bias=self.hparams.bias, - padding_mode=self.hparams.padding_mode, - ) - - convolutions.append(conv) - - # construct activations - activations = [] - for _ in range(len(self.hparams.channels) - 2): - activations.append(utils.get_activation(self.hparams.activation)(inplace=True)) - - layers = list(utils.zip(convolutions, activations, exhaustive=True, nested=False)) - - # add pooling layer if requested - if self.hparams.pool: - match self.hparams.pool_direction: - case "up": - if self.hparams.pool_position == "first": - channels = self.hparams.channels[0] - else: - channels = self.hparams.channels[-1] - - pool = nn.ConvTranspose2d(channels, channels, 2, 2) - case "down": - pool = nn.MaxPool2d(2, 2) - case _: - raise NotImplementedError(f"Unrecognized pool direction '{self.hparams.pool_direction}'.") - - match self.hparams.pool_position: - case "first": - layers.insert(0, pool) - case "last": - layers.append(pool) - case _: - raise NotImplementedError(f"Unrecognized pool position '{self.hparams.pool_position}'.") - - return nn.Sequential(*layers) diff --git a/src/lightning_trainable/modules/convolutional/block_hparams.py b/src/lightning_trainable/modules/convolutional/block_hparams.py deleted file mode 100644 index b901ea9..0000000 --- a/src/lightning_trainable/modules/convolutional/block_hparams.py +++ /dev/null @@ -1,25 +0,0 @@ -from lightning_trainable.hparams import HParams, Choice - - -class ConvolutionalBlockHParams(HParams): - channels: list[int] - kernel_sizes: list[int] - activation: str = "relu" - padding: str | int = 0 - dilation: int = 1 - groups: int = 1 - bias: bool = True - padding_mode: str = "zeros" - - pool: bool = False - pool_direction: Choice("up", "down") = "down" - pool_position: Choice("first", "last") = "last" - - @classmethod - def validate_parameters(cls, hparams): - hparams = super().validate_parameters(hparams) - - if not len(hparams.channels) == len(hparams.kernel_sizes): - raise ValueError(f"{cls.__name__} needs same number of channels and kernel sizes.") - - return hparams diff --git a/src/lightning_trainable/modules/convolutional/hparams.py b/src/lightning_trainable/modules/convolutional/hparams.py deleted file mode 100644 index 046e54e..0000000 --- a/src/lightning_trainable/modules/convolutional/hparams.py +++ /dev/null @@ -1,8 +0,0 @@ - -from lightning_trainable.hparams import HParams - -from .block_hparams import ConvolutionalBlockHParams - - -class ConvolutionalNetworkHParams(HParams): - block_hparams: list[ConvolutionalBlockHParams] diff --git a/src/lightning_trainable/modules/convolutional/network.py b/src/lightning_trainable/modules/convolutional/network.py deleted file mode 100644 index c4ad3ad..0000000 --- a/src/lightning_trainable/modules/convolutional/network.py +++ /dev/null @@ -1,28 +0,0 @@ -import torch -import torch.nn as nn - - -from ..hparams_module import 
HParamsModule -from ..sequential_mixin import SequentialMixin - -from .block import ConvolutionalBlock -from .hparams import ConvolutionalNetworkHParams - - -class ConvolutionalNetwork(SequentialMixin, HParamsModule): - """ - Implements a series of pooled convolutional blocks. - """ - hparams: ConvolutionalNetworkHParams - - def __init__(self, hparams): - super().__init__(hparams) - self.network = self.configure_network() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.network(x) - - def configure_network(self): - blocks = [ConvolutionalBlock(hparams) for hparams in self.hparams.block_hparams] - - return nn.Sequential(*blocks) diff --git a/src/lightning_trainable/modules/fully_connected/hparams.py b/src/lightning_trainable/modules/fully_connected/hparams.py index 2d3dccb..7a981d1 100644 --- a/src/lightning_trainable/modules/fully_connected/hparams.py +++ b/src/lightning_trainable/modules/fully_connected/hparams.py @@ -8,3 +8,6 @@ class FullyConnectedNetworkHParams(HParams): layer_widths: list[int] activation: str = "relu" + + norm: Choice("none", "batch", "layer") = "none" + dropout: float = 0.0 diff --git a/src/lightning_trainable/modules/fully_connected/network.py b/src/lightning_trainable/modules/fully_connected/network.py index d1b0583..b4a133e 100644 --- a/src/lightning_trainable/modules/fully_connected/network.py +++ b/src/lightning_trainable/modules/fully_connected/network.py @@ -1,4 +1,6 @@ -import torch + +from torch import Tensor + import torch.nn as nn from lightning_trainable.utils import get_activation @@ -19,22 +21,63 @@ def __init__(self, hparams): self.network = self.configure_network() - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: Tensor) -> Tensor: return self.network(x) def configure_network(self): - layers = [] - widths = [self.hparams.input_dims, *self.hparams.layer_widths, self.hparams.output_dims] + widths = self.hparams.layer_widths + + # input layer + input_linear = self.configure_linear(self.hparams.input_dims, widths[0]) + input_activation = self.configure_activation() + layers = [input_linear, input_activation] + + # hidden layers + for (in_features, out_features) in zip(widths[:-1], widths[1:]): + dropout = self.configure_dropout() + norm = self.configure_norm(in_features) + + activation = self.configure_activation() + linear = self.configure_linear(in_features, out_features) - if self.hparams.input_dims == "lazy": - widths = widths[1:] - layers = [nn.LazyLinear(widths[0]), get_activation(self.hparams.activation)(inplace=True)] + if dropout is not None: + layers.append(dropout) - for (w1, w2) in zip(widths[:-1], widths[1:]): - layers.append(nn.Linear(w1, w2)) - layers.append(get_activation(self.hparams.activation)(inplace=True)) + if norm is not None: + layers.append(norm) - # remove last activation - layers = layers[:-1] + layers.append(linear) + layers.append(activation) + + # output layer + output_linear = self.configure_linear(widths[-1], self.hparams.output_dims) + layers.append(output_linear) return nn.Sequential(*layers) + + def configure_linear(self, in_features, out_features) -> nn.Module: + match in_features: + case "lazy": + return nn.LazyLinear(out_features) + case int() as in_features: + return nn.Linear(in_features, out_features) + case other: + raise NotImplementedError(f"Unrecognized input_dims value: '{other}'") + + def configure_activation(self) -> nn.Module: + return get_activation(self.hparams.activation)(inplace=True) + + def configure_dropout(self) -> nn.Module | None: + if 
self.hparams.dropout > 0: + return nn.Dropout(self.hparams.dropout) + + def configure_norm(self, num_features) -> nn.Module | None: + match self.hparams.norm: + case "none": + return None + case "batch": + return nn.BatchNorm1d(num_features) + case "layer": + return nn.LayerNorm(num_features) + case other: + raise NotImplementedError(f"Unrecognized norm value: '{other}'") diff --git a/src/lightning_trainable/modules/hparams_module.py b/src/lightning_trainable/modules/hparams_module.py index 225ce17..d8a0691 100644 --- a/src/lightning_trainable/modules/hparams_module.py +++ b/src/lightning_trainable/modules/hparams_module.py @@ -17,4 +17,7 @@ def __init__(self, hparams: HParams | dict): self.hparams = hparams def __init_subclass__(cls, **kwargs): - cls.hparams_type = cls.__annotations__["hparams"] + hparams_type = cls.__annotations__.get("hparams") + if hparams_type is not None: + # only overwrite hparams_type if it is defined by the child class + cls.hparams_type = hparams_type diff --git a/src/lightning_trainable/modules/simple_unet/__init__.py b/src/lightning_trainable/modules/simple_unet/__init__.py new file mode 100644 index 0000000..64c06ad --- /dev/null +++ b/src/lightning_trainable/modules/simple_unet/__init__.py @@ -0,0 +1,3 @@ + +from .hparams import SimpleUNetHParams +from .network import SimpleUNet diff --git a/src/lightning_trainable/modules/simple_unet/down_block.py b/src/lightning_trainable/modules/simple_unet/down_block.py new file mode 100644 index 0000000..3a3cd4f --- /dev/null +++ b/src/lightning_trainable/modules/simple_unet/down_block.py @@ -0,0 +1,29 @@ + +import torch +import torch.nn as nn +from torch.nn import Module + +from lightning_trainable.utils import get_activation + + +class SimpleUNetDownBlock(Module): + def __init__(self, in_channels: int, out_channels: int, kernel_size: int, block_size: int = 2, activation: str = "relu"): + super().__init__() + + self.channels = torch.linspace(in_channels, out_channels, block_size + 1, dtype=torch.int64).tolist() + + layers = [] + for c1, c2 in zip(self.channels[:-2], self.channels[1:-1]): + layers.append(nn.Conv2d(c1, c2, kernel_size, padding="same")) + layers.append(get_activation(activation)(inplace=True)) + + layers.append(nn.Conv2d(self.channels[-2], self.channels[-1], kernel_size, padding="same")) + layers.append(nn.MaxPool2d(2)) + + self.block = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.block(x) + + def extra_repr(self) -> str: + return f"in_channels={self.channels[0]}, out_channels={self.channels[-1]}, kernel_size={self.block[0].kernel_size[0]}, block_size={len(self.channels) - 1}" diff --git a/src/lightning_trainable/modules/simple_unet/hparams.py b/src/lightning_trainable/modules/simple_unet/hparams.py new file mode 100644 index 0000000..bb4714d --- /dev/null +++ b/src/lightning_trainable/modules/simple_unet/hparams.py @@ -0,0 +1,23 @@ +from lightning_trainable.hparams import HParams + + +class SimpleUNetHParams(HParams): + input_shape: tuple[int, int, int] + conditions: int = 0 + + channels: list[int] + kernel_sizes: list[int] + fc_widths: list[int] + activation: str = "ReLU" + + block_size: int = 2 + + @classmethod + def validate_parameters(cls, hparams): + hparams = super().validate_parameters(hparams) + + if len(hparams.channels) + 2 != len(hparams.kernel_sizes): + raise ValueError(f"Number of channels ({len(hparams.channels)}) + 2 must be equal " + f"to the number of kernel sizes ({len(hparams.kernel_sizes)})") + + return hparams diff --git 
a/src/lightning_trainable/modules/simple_unet/network.py b/src/lightning_trainable/modules/simple_unet/network.py new file mode 100644 index 0000000..f2fce9a --- /dev/null +++ b/src/lightning_trainable/modules/simple_unet/network.py @@ -0,0 +1,73 @@ + +import torch +import torch.nn as nn +from torch import Tensor + +from ..fully_connected import FullyConnectedNetwork +from ..hparams_module import HParamsModule + +from .down_block import SimpleUNetDownBlock +from .hparams import SimpleUNetHParams +from .up_block import SimpleUNetUpBlock + + +class SimpleUNet(HParamsModule): + hparams: SimpleUNetHParams + + def __init__(self, hparams: dict | SimpleUNetHParams): + super().__init__(hparams) + + channels = [self.hparams.input_shape[0], *self.hparams.channels] + + self.down_blocks = nn.ModuleList([ + SimpleUNetDownBlock(c1, c2, kernel_size, self.hparams.block_size, self.hparams.activation) + for c1, c2, kernel_size in zip(channels[:-1], channels[1:], self.hparams.kernel_sizes) + ]) + + fc_channels = self.hparams.channels[-1] + height = self.hparams.input_shape[1] // 2 ** (len(channels) - 1) + width = self.hparams.input_shape[2] // 2 ** (len(channels) - 1) + + fc_hparams = dict( + input_dims=self.hparams.conditions + fc_channels * height * width, + output_dims=fc_channels * height * width, + activation=self.hparams.activation, + layer_widths=self.hparams.fc_widths, + ) + self.fc = FullyConnectedNetwork(fc_hparams) + self.up_blocks = nn.ModuleList([ + SimpleUNetUpBlock(c2, c1, kernel_size, self.hparams.block_size, self.hparams.activation) + for c1, c2, kernel_size in zip(channels[:-1], channels[1:], self.hparams.kernel_sizes) + ][::-1]) + + def forward(self, image: Tensor, condition: Tensor = None) -> Tensor: + residuals = [] + for block in self.down_blocks: + image = block(image) + residuals.append(image) + + shape = image.shape + image = image.flatten(start_dim=1) + + if condition is not None: + image = torch.cat([image, condition], dim=1) + + image = self.fc(image) + + image = image.reshape(shape) + + for block in self.up_blocks: + residual = residuals.pop() + image = block(image + residual) + + return image + + def down(self, image: Tensor) -> Tensor: + for block in self.down_blocks: + image = block(image) + return image + + def up(self, image: Tensor) -> Tensor: + for block in self.up_blocks: + image = block(image) + return image diff --git a/src/lightning_trainable/modules/simple_unet/up_block.py b/src/lightning_trainable/modules/simple_unet/up_block.py new file mode 100644 index 0000000..b389626 --- /dev/null +++ b/src/lightning_trainable/modules/simple_unet/up_block.py @@ -0,0 +1,29 @@ + +import torch +import torch.nn as nn +from torch.nn import Module + +from lightning_trainable.utils import get_activation + + +class SimpleUNetUpBlock(Module): + def __init__(self, in_channels: int, out_channels: int, kernel_size: int, block_size: int = 2, activation: str = "relu"): + super().__init__() + + self.channels = torch.linspace(in_channels, out_channels, block_size + 1, dtype=torch.int64).tolist() + + layers = [] + for c1, c2 in zip(self.channels[:-2], self.channels[1:-1]): + layers.append(nn.Conv2d(c1, c2, kernel_size, padding="same")) + layers.append(get_activation(activation)(inplace=True)) + + layers.append(nn.Conv2d(self.channels[-2], self.channels[-1], kernel_size, padding="same")) + layers.append(nn.ConvTranspose2d(self.channels[-1], self.channels[-1], 2, 2)) + + self.block = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.block(x) + + def 
extra_repr(self) -> str: + return f"in_channels={self.channels[0]}, out_channels={self.channels[-1]}, kernel_size={self.block[0].kernel_size[0]}, block_size={len(self.channels) - 1}" diff --git a/src/lightning_trainable/modules/unet/__init__.py b/src/lightning_trainable/modules/unet/__init__.py deleted file mode 100644 index 964f0cb..0000000 --- a/src/lightning_trainable/modules/unet/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .hparams import UNetHParams, UNetBlockHParams -from .network import UNet diff --git a/src/lightning_trainable/modules/unet/hparams.py b/src/lightning_trainable/modules/unet/hparams.py deleted file mode 100644 index 6afc972..0000000 --- a/src/lightning_trainable/modules/unet/hparams.py +++ /dev/null @@ -1,66 +0,0 @@ -from lightning_trainable.hparams import HParams, Choice - - -class UNetBlockHParams(HParams): - # this is a reduced variant of the ConvolutionalBlockHParams - # where additional parameters are generated automatically by UNet - channels: list[int] - kernel_sizes: list[int] - - @classmethod - def validate_parameters(cls, hparams): - hparams = super().validate_parameters(hparams) - - if len(hparams.channels) + 1 != len(hparams.kernel_sizes): - raise ValueError(f"Block needs one more kernel size than channels.") - - return hparams - - -class UNetHParams(HParams): - # input/output image size, not including batch dimensions - input_shape: tuple[int, int, int] - output_shape: tuple[int, int, int] - - # list of hparams for individual down/up blocks - down_blocks: list[dict | UNetBlockHParams] - up_blocks: list[dict | UNetBlockHParams] - - # hidden layer sizes for the bottom, fully connected part of the UNet - bottom_widths: list[int] - - # skip connection mode - skip_mode: Choice("add", "concat", "none") = "add" - activation: str = "relu" - - @classmethod - def validate_parameters(cls, hparams): - hparams = super().validate_parameters(hparams) - - for i in range(len(hparams.down_blocks)): - hparams.down_blocks[i] = UNetBlockHParams(**hparams.down_blocks[i]) - for i in range(len(hparams.up_blocks)): - hparams.up_blocks[i] = UNetBlockHParams(**hparams.up_blocks[i]) - - url = "https://github.com/LarsKue/lightning-trainable/" - if hparams.input_shape[1:] != hparams.output_shape[1:]: - raise ValueError(f"Different image sizes for input and output are not yet supported. " - f"If you need this feature, please file an issue or pull request at {url}.") - - if hparams.input_shape[1] % 2 or hparams.input_shape[2] % 2: - raise ValueError(f"Odd input shape is not yet supported. " - f"If you need this feature, please file an issue or pull request at {url}.") - - minimum_size = 2 ** len(hparams.down_blocks) - if hparams.input_shape[1] < minimum_size or hparams.input_shape[2] < minimum_size: - raise ValueError(f"Input shape {hparams.input_shape[1:]} is too small for {len(hparams.down_blocks)} " - f"down blocks. 
Minimum size is {(minimum_size, minimum_size)}.") - - if hparams.skip_mode == "add": - # ensure matching number of channels for down output as up input - for i, (down_block, up_block) in enumerate(zip(hparams.down_blocks, reversed(hparams.up_blocks))): - if down_block["channels"][-1] != up_block["channels"][0]: - raise ValueError(f"Output channels of down block {i} must match input channels of up block " - f"{len(hparams.up_blocks) - (i + 1)} for skip mode '{hparams.skip_mode}'.") - - return hparams diff --git a/src/lightning_trainable/modules/unet/network.py b/src/lightning_trainable/modules/unet/network.py deleted file mode 100644 index f40a31a..0000000 --- a/src/lightning_trainable/modules/unet/network.py +++ /dev/null @@ -1,100 +0,0 @@ -import torch -import torch.nn as nn - -import math - -from lightning_trainable.hparams import AttributeDict - -from ..convolutional import ConvolutionalBlock -from ..fully_connected import FullyConnectedNetwork -from ..hparams_module import HParamsModule - -from .hparams import UNetHParams -from .skip_connection import SkipConnection -from .temporary_flatten import TemporaryFlatten - - -class UNet(HParamsModule): - hparams: UNetHParams - - def __init__(self, hparams: dict | UNetHParams): - super().__init__(hparams) - - self.network = self.configure_network( - input_shape=self.hparams.input_shape, - output_shape=self.hparams.output_shape, - down_blocks=self.hparams.down_blocks, - up_blocks=self.hparams.up_blocks, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.network(x) - - def configure_levels(self, input_shape: (int, int, int), output_shape: (int, int, int), down_blocks: list[dict], up_blocks: list[dict]): - """ - Recursively configures the levels of the UNet. - """ - if not down_blocks: - return self.configure_fc(input_shape, output_shape) - - down_block = down_blocks[0] - up_block = up_blocks[-1] - - down_hparams = AttributeDict( - channels=[input_shape[0], *down_block["channels"]], - kernel_sizes=down_block["kernel_sizes"], - activation=self.hparams.activation, - padding="same", - pool=True, - pool_direction="down", - pool_position="last" - ) - up_hparams = AttributeDict( - channels=[*up_block["channels"], output_shape[0]], - kernel_sizes=up_block["kernel_sizes"], - activation=self.hparams.activation, - padding="same", - pool=True, - pool_direction="up", - pool_position="first" - ) - - next_input_shape = (down_hparams.channels[-1], input_shape[1] // 2, input_shape[2] // 2) - next_output_shape = (up_hparams.channels[0], output_shape[1] // 2, output_shape[2] // 2) - - if self.hparams.skip_mode == "concat": - up_hparams.channels[0] += down_hparams.channels[-1] - - down_block = ConvolutionalBlock(down_hparams) - up_block = ConvolutionalBlock(up_hparams) - - next_level = self.configure_levels( - input_shape=next_input_shape, - output_shape=next_output_shape, - down_blocks=down_blocks[1:], - up_blocks=up_blocks[:-1], - ) - - return nn.Sequential( - down_block, - SkipConnection(next_level, mode=self.hparams.skip_mode), - up_block, - ) - - def configure_fc(self, input_shape: (int, int, int), output_shape: (int, int, int)): - """ - Configures the lowest level of the UNet as a fully connected network. 
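The UNet modules being removed in this diff are superseded by the SimpleUNet added earlier. A minimal sketch of the replacement, with assumed shapes and hyperparameters (note that the validator above expects two more kernel sizes than channels):

import torch
from lightning_trainable.modules import SimpleUNet, SimpleUNetHParams

hparams = SimpleUNetHParams(
    input_shape=(1, 32, 32),      # (channels, height, width)
    channels=[16, 32],            # two down/up levels
    kernel_sizes=[3, 3, 3, 3],    # len(channels) + 2 kernel sizes, as required by the validator
    fc_widths=[64],               # hidden widths of the fully-connected bottleneck
    activation="relu",
)
unet = SimpleUNet(hparams)

x = torch.randn(8, 1, 32, 32)
assert unet(x).shape == x.shape   # output shape matches the input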
- """ - hparams = dict( - input_dims=math.prod(input_shape), - output_dims=math.prod(output_shape), - layer_widths=self.hparams.bottom_widths, - activation=self.hparams.activation, - ) - return TemporaryFlatten(FullyConnectedNetwork(hparams), input_shape, output_shape) - - def configure_network(self, input_shape: (int, int, int), output_shape: (int, int, int), down_blocks: list[dict], up_blocks: list[dict]): - """ - Configures the UNet. - """ - return self.configure_levels(input_shape, output_shape, down_blocks, up_blocks) diff --git a/src/lightning_trainable/modules/unet/skip_connection.py b/src/lightning_trainable/modules/unet/skip_connection.py deleted file mode 100644 index c6ac170..0000000 --- a/src/lightning_trainable/modules/unet/skip_connection.py +++ /dev/null @@ -1,23 +0,0 @@ -import torch -import torch.nn as nn - - -class SkipConnection(nn.Module): - def __init__(self, inner: nn.Module, mode: str = "add"): - super().__init__() - self.inner = inner - self.mode = mode - - def forward(self, x: torch.Tensor) -> torch.Tensor: - match self.mode: - case "add": - return self.inner(x) + x - case "concat": - return torch.cat((self.inner(x), x), dim=1) - case "none": - return self.inner(x) - case other: - raise NotImplementedError(f"Unrecognized skip connection mode '{other}'.") - - def extra_repr(self) -> str: - return f"mode={self.mode}" diff --git a/src/lightning_trainable/modules/unet/temporary_flatten.py b/src/lightning_trainable/modules/unet/temporary_flatten.py deleted file mode 100644 index 4b23244..0000000 --- a/src/lightning_trainable/modules/unet/temporary_flatten.py +++ /dev/null @@ -1,14 +0,0 @@ -import torch -import torch.nn as nn - - -class TemporaryFlatten(nn.Module): - def __init__(self, inner: nn.Module, input_shape, output_shape): - super().__init__() - self.inner = inner - self.input_shape = input_shape - self.output_shape = output_shape - - def forward(self, x: torch.Tensor) -> torch.Tensor: - out = self.inner(x.flatten(1)) - return out.reshape(x.shape[0], *self.output_shape) diff --git a/src/lightning_trainable/trainable/lr_schedulers/__init__.py b/src/lightning_trainable/trainable/lr_schedulers/__init__.py new file mode 100644 index 0000000..9234965 --- /dev/null +++ b/src/lightning_trainable/trainable/lr_schedulers/__init__.py @@ -0,0 +1 @@ +from .configure import configure diff --git a/src/lightning_trainable/trainable/lr_schedulers/configure.py b/src/lightning_trainable/trainable/lr_schedulers/configure.py new file mode 100644 index 0000000..89c0307 --- /dev/null +++ b/src/lightning_trainable/trainable/lr_schedulers/configure.py @@ -0,0 +1,40 @@ + +from torch.optim.lr_scheduler import LRScheduler + +from lightning_trainable.utils import get_scheduler + +from . 
import defaults + + +def with_kwargs(model, optimizer, **kwargs) -> LRScheduler: + """ + Get a learning rate scheduler with the given kwargs + Insert default values for missing kwargs + @param model: Trainable model + @param optimizer: The optimizer to use with the scheduler + @param kwargs: Keyword arguments for the scheduler + @return: The configured learning rate scheduler + """ + name = kwargs.pop("name") + default_kwargs = defaults.get_defaults(name, model, optimizer) + kwargs = default_kwargs | kwargs + return get_scheduler(name)(optimizer, **kwargs) + + +def configure(model, optimizer) -> LRScheduler | None: + """ + Configure a learning rate scheduler from the model's hparams + @param model: Trainable model + @param optimizer: The optimizer to use with the scheduler + @return: The configured learning rate scheduler + """ + match model.hparams.lr_scheduler: + case str() as name: + return with_kwargs(model, optimizer, name=name) + case dict() as kwargs: + return with_kwargs(model, optimizer, **kwargs.copy()) + case None: + # do not use a learning rate scheduler + return None + case other: + raise NotImplementedError(f"Unrecognized Scheduler: '{other}'") diff --git a/src/lightning_trainable/trainable/lr_schedulers/defaults.py b/src/lightning_trainable/trainable/lr_schedulers/defaults.py new file mode 100644 index 0000000..18a672f --- /dev/null +++ b/src/lightning_trainable/trainable/lr_schedulers/defaults.py @@ -0,0 +1,16 @@ + +def get_defaults(scheduler_name, model, optimizer): + match scheduler_name: + case "OneCycleLR": + max_lr = optimizer.defaults["lr"] + total_steps = model.hparams.max_steps + if total_steps == -1: + total_steps = model.hparams.max_epochs * len(model.train_dataloader()) + total_steps = int(total_steps / model.hparams.accumulate_batches) + + return dict( + max_lr=max_lr, + total_steps=total_steps + ) + case _: + return dict() diff --git a/src/lightning_trainable/trainable/optimizers/__init__.py b/src/lightning_trainable/trainable/optimizers/__init__.py new file mode 100644 index 0000000..9234965 --- /dev/null +++ b/src/lightning_trainable/trainable/optimizers/__init__.py @@ -0,0 +1 @@ +from .configure import configure diff --git a/src/lightning_trainable/trainable/optimizers/configure.py b/src/lightning_trainable/trainable/optimizers/configure.py new file mode 100644 index 0000000..3828d13 --- /dev/null +++ b/src/lightning_trainable/trainable/optimizers/configure.py @@ -0,0 +1,38 @@ + +from torch.optim import Optimizer + +from lightning_trainable.utils import get_optimizer + +from . 
import defaults + + +def with_kwargs(model, **kwargs) -> Optimizer: + """ + Get an optimizer with the given kwargs + Insert default values for missing kwargs + @param model: Trainable model + @param kwargs: Keyword arguments for the optimizer + @return: The configured optimizer + """ + name = kwargs.pop("name") + default_kwargs = defaults.get_defaults(name, model) + kwargs = default_kwargs | kwargs + return get_optimizer(name)(**kwargs) + + +def configure(model) -> Optimizer | None: + """ + Configure an optimizer from the model's hparams + @param model: Trainable model + @return: The configured optimizer + """ + match model.hparams.optimizer: + case str() as name: + return with_kwargs(model, name=name) + case dict() as kwargs: + return with_kwargs(model, **kwargs.copy()) + case None: + # do not use an optimizer + return None + case other: + raise NotImplementedError(f"Unrecognized Optimizer: '{other}'") diff --git a/src/lightning_trainable/trainable/optimizers/defaults.py b/src/lightning_trainable/trainable/optimizers/defaults.py new file mode 100644 index 0000000..5b9b23c --- /dev/null +++ b/src/lightning_trainable/trainable/optimizers/defaults.py @@ -0,0 +1,7 @@ + +def get_defaults(optimizer_name, model): + match optimizer_name: + case _: + return dict( + params=model.parameters(), + ) diff --git a/src/lightning_trainable/trainable/trainable.py b/src/lightning_trainable/trainable/trainable.py index 210a7cb..9684043 100644 --- a/src/lightning_trainable/trainable/trainable.py +++ b/src/lightning_trainable/trainable/trainable.py @@ -1,18 +1,19 @@ -import os -import pathlib -from copy import deepcopy - import lightning +import os import torch -from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint, ProgressBar, EarlyStopping + +from copy import deepcopy +from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint, EarlyStopping from lightning.pytorch.loggers import Logger, TensorBoardLogger from torch.utils.data import DataLoader, Dataset, IterableDataset from tqdm import tqdm from lightning_trainable import utils -from lightning_trainable.callbacks import EpochProgressBar from .trainable_hparams import TrainableHParams +from . import lr_schedulers +from . import optimizers + class SkipBatch(Exception): pass @@ -33,15 +34,16 @@ def __init__( if not isinstance(hparams, self.hparams_type): hparams = self.hparams_type(**hparams) self.save_hyperparameters(hparams) - # workaround for https://github.com/Lightning-AI/lightning/issues/17889 - self._hparams_name = "hparams" self.train_data = train_data self.val_data = val_data self.test_data = test_data def __init_subclass__(cls, **kwargs): - cls.hparams_type = cls.__annotations__.get("hparams", TrainableHParams) + hparams_type = cls.__annotations__.get("hparams") + if hparams_type is not None: + # only overwrite hparams_type if it is defined by the child class + cls.hparams_type = hparams_type def compute_metrics(self, batch, batch_idx) -> dict: """ @@ -89,97 +91,18 @@ def test_step(self, batch, batch_idx): self.log(f"test/{key}", value, prog_bar=key == self.hparams.loss) - def configure_lr_schedulers(self, optimizer): - """ - Configure the LR Scheduler as defined in HParams. - By default, we only use a single LR Scheduler, attached to a single optimizer. - You can use a ChainedScheduler if you need multiple LR Schedulers throughout training, - or override this method if you need different schedulers for different parameters. - - @param optimizer: The optimizer to attach the scheduler to. 
- @return: The LR Scheduler object. - """ - match self.hparams.lr_scheduler: - case str() as name: - match name.lower(): - case "onecyclelr": - kwargs = dict( - max_lr=optimizer.defaults["lr"], - epochs=self.hparams.max_epochs, - steps_per_epoch=len(self.train_dataloader()) - ) - interval = "step" - case _: - kwargs = dict() - interval = "step" - scheduler = utils.get_scheduler(name)(optimizer, **kwargs) - return dict( - scheduler=scheduler, - interval=interval, - ) - case dict() as kwargs: - # Copy the dict so we don't modify the original - kwargs = deepcopy(kwargs) - - name = kwargs.pop("name") - interval = "step" - if "interval" in kwargs: - interval = kwargs.pop("interval") - config_kwargs = kwargs.pop("config", dict()) - scheduler = utils.get_scheduler(name)(optimizer, **kwargs) - return dict( - scheduler=scheduler, - interval=interval, - **config_kwargs - ) - case type(torch.optim.lr_scheduler.LRScheduler) as Scheduler: - kwargs = dict() - interval = "step" - scheduler = Scheduler(optimizer, **kwargs) - return dict( - scheduler=scheduler, - interval=interval, - ) - case (torch.optim.lr_scheduler.LRScheduler() | torch.optim.lr_scheduler.ReduceLROnPlateau()) as scheduler: - return dict( - scheduler=scheduler, - interval="step", - ) - case None: - # do not use a scheduler - return None - case other: - raise NotImplementedError(f"Unrecognized Scheduler: {other}") - def configure_optimizers(self): """ - Configure Optimizer and LR Scheduler objects as defined in HParams. - By default, we only use a single optimizer and an optional LR Scheduler. - If you need multiple optimizers, override this method. + Configure the optimizer and learning rate scheduler for this model, based on the HParams. + This method is called automatically by the Lightning Trainer in module fitting. - @return: A dictionary containing the optimizer and lr_scheduler. - """ - kwargs = dict() - - match self.hparams.optimizer: - case str() as name: - optimizer = utils.get_optimizer(name)(self.parameters(), **kwargs) - case dict() as kwargs: - # Copy the dict so we don't modify the original - kwargs = deepcopy(kwargs) - - name = kwargs.pop("name") - optimizer = utils.get_optimizer(name)(self.parameters(), **kwargs) - case type(torch.optim.Optimizer) as Optimizer: - optimizer = Optimizer(self.parameters(), **kwargs) - case torch.optim.Optimizer() as optimizer: - pass - case None: - return None - case other: - raise NotImplementedError(f"Unrecognized Optimizer: {other}") + By default, we use one optimizer and zero or one learning rate scheduler. + If you want to use multiple optimizers or learning rate schedulers, you must override this method. - lr_scheduler = self.configure_lr_schedulers(optimizer) + @return: A dictionary containing the optimizer and learning rate scheduler, if any. 
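With the refactor above, the optimizer and scheduler are configured purely from hparams; a sketch with illustrative values. Names are now case-sensitive ("Adam", "OneCycleLR"), and for OneCycleLR the defaults fill in max_lr from the optimizer and total_steps from max_epochs, the train dataloader length and accumulate_batches (see lr_schedulers/defaults.py above).

from lightning_trainable.trainable import TrainableHParams

hparams = TrainableHParams(
    max_epochs=10,
    batch_size=32,
    optimizer=dict(name="Adam", lr=1e-3),   # extra keys are forwarded to torch.optim.Adam
    lr_scheduler="OneCycleLR",              # missing kwargs are filled from the defaults module
)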
+ """ + optimizer = optimizers.configure(self) + lr_scheduler = lr_schedulers.configure(self, optimizer) if lr_scheduler is None: return optimizer @@ -207,7 +130,6 @@ def configure_callbacks(self) -> list: callbacks = [ ModelCheckpoint(**checkpoint_kwargs), LearningRateMonitor(), - EpochProgressBar(), ] if self.hparams.early_stopping is not None: if self.hparams.early_stopping["monitor"] == "auto": @@ -263,28 +185,22 @@ def test_dataloader(self) -> DataLoader | list[DataLoader]: num_workers=self.hparams.num_workers, ) - def configure_logger(self, save_dir=os.getcwd(), **kwargs) -> Logger: + def configure_logger(self, logger_name: str = "TensorBoardLogger", **logger_kwargs) -> Logger: """ Instantiate the Logger used by the Trainer in module fitting. By default, we use a TensorBoardLogger, but you can use any other logger of your choice. - @param save_dir: The root directory in which all your experiments with - different names and versions will be stored. - @param kwargs: Keyword-Arguments to the Logger. Set `logger_name` to use a different logger than TensorBoardLogger. + @param logger_name: The name of the logger to use. Defaults to TensorBoardLogger. + @param logger_kwargs: Keyword-Arguments to the Logger. Set `logger_name` to use a different logger than TensorBoardLogger. @return: The Logger object. """ - logger_kwargs = deepcopy(kwargs) - logger_kwargs.update(dict( - save_dir=save_dir, - )) - logger_name = logger_kwargs.pop("logger_name", "TensorBoardLogger") + logger_kwargs = logger_kwargs or {} + logger_kwargs.setdefault("save_dir", os.getcwd()) logger_class = utils.get_logger(logger_name) if issubclass(logger_class, TensorBoardLogger): - logger_kwargs["default_hp_metric"] = False + logger_kwargs.setdefault("default_hp_metric", False) - return logger_class( - **logger_kwargs - ) + return logger_class(**logger_kwargs) def configure_trainer(self, logger_kwargs: dict = None, trainer_kwargs: dict = None) -> lightning.Trainer: """ @@ -296,27 +212,20 @@ def configure_trainer(self, logger_kwargs: dict = None, trainer_kwargs: dict = N See also :func:`~trainable.Trainable.configure_trainer`. @return: The Lightning Trainer object. 
""" - if logger_kwargs is None: - logger_kwargs = dict() - if trainer_kwargs is None: - trainer_kwargs = dict() - - if "enable_progress_bar" not in trainer_kwargs: - if any(isinstance(callback, ProgressBar) for callback in self.configure_callbacks()): - trainer_kwargs["enable_progress_bar"] = False - - return lightning.Trainer( - accelerator=self.hparams.accelerator.lower(), - logger=self.configure_logger(**logger_kwargs), - devices=self.hparams.devices, - max_epochs=self.hparams.max_epochs, - max_steps=self.hparams.max_steps, - gradient_clip_val=self.hparams.gradient_clip, - accumulate_grad_batches=self.hparams.accumulate_batches, - profiler=self.hparams.profiler, - benchmark=True, - **trainer_kwargs, - ) + logger_kwargs = logger_kwargs or {} + trainer_kwargs = trainer_kwargs or {} + + trainer_kwargs.setdefault("accelerator", self.hparams.accelerator.lower()) + trainer_kwargs.setdefault("accumulate_grad_batches", self.hparams.accumulate_batches) + trainer_kwargs.setdefault("benchmark", True) + trainer_kwargs.setdefault("devices", self.hparams.devices) + trainer_kwargs.setdefault("gradient_clip_val", self.hparams.gradient_clip) + trainer_kwargs.setdefault("logger", self.configure_logger(**logger_kwargs)) + trainer_kwargs.setdefault("max_epochs", self.hparams.max_epochs) + trainer_kwargs.setdefault("max_steps", self.hparams.max_steps) + trainer_kwargs.setdefault("profiler", self.hparams.profiler) + + return lightning.Trainer(**trainer_kwargs) def on_before_optimizer_step(self, optimizer): # who doesn't love breaking changes in underlying libraries @@ -329,6 +238,14 @@ def on_before_optimizer_step(self, optimizer): case other: raise NotImplementedError(f"Unrecognized grad norm: {other}") + def on_train_start(self) -> None: + # get hparams metrics with a test batch + test_batch = next(iter(self.trainer.train_dataloader)) + metrics = self.compute_metrics(test_batch, 0) + + # add hparams to tensorboard + self.logger.log_hyperparams(self.hparams, metrics) + @torch.enable_grad() def fit(self, logger_kwargs: dict = None, trainer_kwargs: dict = None, fit_kwargs: dict = None) -> dict: """ @@ -341,26 +258,18 @@ def fit(self, logger_kwargs: dict = None, trainer_kwargs: dict = None, fit_kwarg @param fit_kwargs: Keyword-Arguments to the Trainer's fit method. @return: Validation Metrics as defined in :func:`~trainable.Trainable.compute_metrics`. 
""" - if logger_kwargs is None: - logger_kwargs = dict() - if trainer_kwargs is None: - trainer_kwargs = dict() - if fit_kwargs is None: - fit_kwargs = dict() + logger_kwargs = logger_kwargs or {} + trainer_kwargs = trainer_kwargs or {} + fit_kwargs = fit_kwargs or {} trainer = self.configure_trainer(logger_kwargs, trainer_kwargs) - metrics_list = trainer.validate(self) - if metrics_list is not None and len(metrics_list) > 0: - metrics = metrics_list[0] - else: - metrics = {} - trainer.logger.log_hyperparams(self.hparams, metrics) + trainer.fit(self, **fit_kwargs) return { key: value.item() for key, value in trainer.callback_metrics.items() - if any(key.startswith(key) for key in ["training/", "validation/"]) + if any(key.startswith(k) for k in ["training/", "validation/"]) } @torch.enable_grad() @@ -378,7 +287,14 @@ def fit_fast(self, device="cuda"): self.train() self.to(device) - optimizer = self.configure_optimizers()["optimizer"] + maybe_optimizer = self.configure_optimizers() + if isinstance(maybe_optimizer, dict): + optimizer = maybe_optimizer["optimizer"] + elif isinstance(maybe_optimizer, torch.optim.Optimizer): + optimizer = maybe_optimizer + else: + raise RuntimeError("Invalid optimizer") + dataloader = self.train_dataloader() loss = None @@ -387,18 +303,12 @@ def fit_fast(self, device="cuda"): batch = tuple(t.to(device) for t in batch if torch.is_tensor(t)) optimizer.zero_grad() - loss = self.training_step(batch, 0) + loss = self.training_step(batch, batch_idx) loss.backward() optimizer.step() return loss - @classmethod - def load_checkpoint(cls, root: str | pathlib.Path = "lightning_logs", version: int | str = "last", - epoch: int | str = "last", step: int | str = "last", **kwargs): - checkpoint = utils.find_checkpoint(root, version, epoch, step) - return cls.load_from_checkpoint(checkpoint, **kwargs) - def auto_pin_memory(pin_memory: bool | None, accelerator: str): if pin_memory is None: diff --git a/src/lightning_trainable/trainable/trainable_hparams.py b/src/lightning_trainable/trainable/trainable_hparams.py index 3b58994..d27a455 100644 --- a/src/lightning_trainable/trainable/trainable_hparams.py +++ b/src/lightning_trainable/trainable/trainable_hparams.py @@ -1,6 +1,7 @@ from lightning_trainable.hparams import HParams from lightning.pytorch.profilers import Profiler +from lightning_trainable.utils import deprecate class TrainableHParams(HParams): @@ -11,7 +12,7 @@ class TrainableHParams(HParams): devices: int = 1 max_epochs: int | None max_steps: int = -1 - optimizer: str | dict | None = "adam" + optimizer: str | dict | None = "Adam" lr_scheduler: str | dict | None = None batch_size: int accumulate_batches: int = 1 @@ -31,7 +32,37 @@ class TrainableHParams(HParams): @classmethod def _migrate_hparams(cls, hparams): if "accumulate_batches" in hparams and hparams["accumulate_batches"] is None: + deprecate("accumulate_batches changed default value: None -> 1") hparams["accumulate_batches"] = 1 + + if "optimizer" in hparams: + match hparams["optimizer"]: + case str() as name: + if name == name.lower(): + deprecate("optimizer name is now case-sensitive.") + if name == "adam": + hparams["optimizer"] = "Adam" + case dict() as kwargs: + name = kwargs["name"] + if name == name.lower(): + deprecate("optimizer name is now case-sensitive.") + if name == "adam": + hparams["optimizer"]["name"] = "Adam" + + if "lr_scheduler" in hparams: + match hparams["lr_scheduler"]: + case str() as name: + if name == name.lower(): + deprecate("lr_scheduler name is now case-sensitive.") + if name == 
"onecyclelr": + hparams["lr_scheduler"] = "OneCycleLR" + case dict() as kwargs: + name = kwargs["name"] + if name == name.lower(): + deprecate("lr_scheduler name is now case-sensitive.") + if name == "onecyclelr": + hparams["lr_scheduler"]["name"] = "OneCycleLR" + if "early_stopping" in hparams and isinstance(hparams["early_stopping"], int): hparams["early_stopping"] = dict(monitor="auto", patience=hparams["early_stopping"]) return hparams diff --git a/src/lightning_trainable/utils/__init__.py b/src/lightning_trainable/utils/__init__.py index 426dc0f..def6be2 100644 --- a/src/lightning_trainable/utils/__init__.py +++ b/src/lightning_trainable/utils/__init__.py @@ -1,3 +1,4 @@ +from .deprecate import deprecate from .io import find_checkpoint from .iteration import flatten, zip from .modules import ( diff --git a/src/lightning_trainable/utils/deprecate.py b/src/lightning_trainable/utils/deprecate.py new file mode 100644 index 0000000..ff592d4 --- /dev/null +++ b/src/lightning_trainable/utils/deprecate.py @@ -0,0 +1,33 @@ + +import warnings +from functools import wraps +import inspect + + +def deprecate(message: str, version: str = None): + + def wrapper(obj): + nonlocal message + + if inspect.isclass(obj): + name = "Class" + elif inspect.isfunction(obj): + name = "Function" + elif inspect.ismethod(obj): + name = "Method" + else: + name = "Object" + + if version is not None: + message = f"{name} '{obj.__name__}' is deprecated since version {version}: {message}" + else: + message = f"{name} '{obj.__name__}' is deprecated: {message}" + + @wraps(obj) + def wrapped(*args, **kwargs): + warnings.warn(message, DeprecationWarning, stacklevel=2) + return obj(*args, **kwargs) + + return wrapped + + return wrapper diff --git a/src/lightning_trainable/utils/io.py b/src/lightning_trainable/utils/io.py index 73a0931..b44a4c2 100644 --- a/src/lightning_trainable/utils/io.py +++ b/src/lightning_trainable/utils/io.py @@ -9,6 +9,10 @@ def find_version(root: str | Path = "lightning_logs", version: int = "last") -> # Determine latest version number if "last" is passed as version number if version == "last": version_folders = [f for f in root.iterdir() if f.is_dir() and re.match(r"^version_(\d+)$", f.name)] + + if not version_folders: + raise FileNotFoundError(f"Found no folders in '{root}' matching the name pattern 'version_'") + version_numbers = [int(re.match(r"^version_(\d+)$", f.name).group(1)) for f in version_folders] version = max(version_numbers) @@ -25,7 +29,6 @@ def find_epoch_step(root: str | Path, epoch: int = "last", step: int = "last") - @param step: Step number or "last" @return: epoch and step numbers """ - root = Path(root) # get checkpoint filenames @@ -36,6 +39,10 @@ def find_epoch_step(root: str | Path, epoch: int = "last", step: int = "last") - # remove invalid files checkpoints = [cp for cp in checkpoints if pattern.match(cp)] + if not checkpoints: + raise FileNotFoundError(f"Found no checkpoints in '{root}' matching " + f"the name pattern 'epoch=-step=.ckpt'") + # get epochs and steps as list matches = [pattern.match(cp) for cp in checkpoints] epochs, steps = zip(*[(int(match.group(1)), int(match.group(2))) for match in matches]) @@ -43,7 +50,8 @@ def find_epoch_step(root: str | Path, epoch: int = "last", step: int = "last") - if epoch == "last": epoch = max(epochs) elif epoch not in epochs: - raise FileNotFoundError(f"No checkpoint in '{root}' for epoch '{epoch}'.") + closest_epoch = min(epochs, key=lambda e: abs(e - epoch)) + raise FileNotFoundError(f"No checkpoint in '{root}' for epoch 
'{epoch}'. Closest is '{closest_epoch}'.") # keep only steps for this epoch steps = [s for i, s in enumerate(steps) if epochs[i] == epoch] @@ -51,7 +59,9 @@ def find_epoch_step(root: str | Path, epoch: int = "last", step: int = "last") - if step == "last": step = max(steps) elif step not in steps: - raise FileNotFoundError(f"No checkpoint in '{root}' for epoch '{epoch}', step '{step}'") + closest_step = min(steps, key=lambda s: abs(s - step)) + raise FileNotFoundError(f"No checkpoint in '{root}' for epoch '{epoch}', step '{step}'. " + f"Closest step for this epoch is '{closest_step}'.") return epoch, step @@ -60,7 +70,7 @@ def find_checkpoint(root: str | Path = "lightning_logs", version: int = "last", """ Helper method to find a lightning checkpoint based on version, epoch and step numbers. - @param root: logs root directory. Usually "lightning_logs/" + @param root: logs root directory. Usually "lightning_logs/" or "lightning_logs/" @param version: version number or "last" @param epoch: epoch number or "last" @param step: step number or "last" @@ -69,7 +79,7 @@ def find_checkpoint(root: str | Path = "lightning_logs", version: int = "last", root = Path(root) if not root.is_dir(): - raise ValueError(f"Root directory '{root}' does not exist") + raise ValueError(f"Checkpoint root directory '{root}' does not exist") # get existing version number or error version = find_version(root, version) @@ -87,7 +97,4 @@ def find_checkpoint(root: str | Path = "lightning_logs", version: int = "last", checkpoint = checkpoint_folder / f"epoch={epoch}-step={step}.ckpt" - if not checkpoint.is_file(): - raise FileNotFoundError(f"{version=}, {epoch=}, {step=}") - return str(checkpoint) diff --git a/tests/module_tests/__init__.py b/tests/module_tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/module_tests/test_trainable.py b/tests/module_tests/test_trainable.py new file mode 100644 index 0000000..31283cf --- /dev/null +++ b/tests/module_tests/test_trainable.py @@ -0,0 +1,145 @@ + +import pytest + +import torch +import torch.nn as nn +from torch.utils.data import TensorDataset + +from pathlib import Path + +from lightning_trainable.hparams import HParams +from lightning_trainable.trainable import Trainable, TrainableHParams +from lightning_trainable.utils import find_checkpoint + + +@pytest.fixture +def dummy_dataset(): + return TensorDataset(torch.randn(128, 2)) + + +@pytest.fixture +def dummy_network(): + return nn.Linear(2, 2) + + +@pytest.fixture +def dummy_hparams_cls(): + class DummyHParams(TrainableHParams): + max_epochs: int = 10 + batch_size: int = 4 + accelerator: str = "cpu" + + return DummyHParams + + +@pytest.fixture +def dummy_hparams(dummy_hparams_cls): + return dummy_hparams_cls() + + +@pytest.fixture +def dummy_model_cls(dummy_network, dummy_dataset, dummy_hparams_cls): + class DummyModel(Trainable): + hparams: dummy_hparams_cls + + def __init__(self, hparams): + super().__init__(hparams, train_data=dummy_dataset, val_data=dummy_dataset, test_data=dummy_dataset) + self.network = dummy_network + + def compute_metrics(self, batch, batch_idx) -> dict: + return dict( + loss=torch.tensor(0.0, requires_grad=True) + ) + + return DummyModel + + +@pytest.fixture +def dummy_model(dummy_model_cls, dummy_hparams): + return dummy_model_cls(dummy_hparams) + + +def test_fit(dummy_model): + train_metrics = dummy_model.fit() + + assert isinstance(train_metrics, dict) + assert "training/loss" in train_metrics + assert "validation/loss" in train_metrics + + +def 
diff --git a/tests/module_tests/test_trainable.py b/tests/module_tests/test_trainable.py
new file mode 100644
index 0000000..31283cf
--- /dev/null
+++ b/tests/module_tests/test_trainable.py
@@ -0,0 +1,145 @@
+
+import pytest
+
+import torch
+import torch.nn as nn
+from torch.utils.data import TensorDataset
+
+from pathlib import Path
+
+from lightning_trainable.hparams import HParams
+from lightning_trainable.trainable import Trainable, TrainableHParams
+from lightning_trainable.utils import find_checkpoint
+
+
+@pytest.fixture
+def dummy_dataset():
+    return TensorDataset(torch.randn(128, 2))
+
+
+@pytest.fixture
+def dummy_network():
+    return nn.Linear(2, 2)
+
+
+@pytest.fixture
+def dummy_hparams_cls():
+    class DummyHParams(TrainableHParams):
+        max_epochs: int = 10
+        batch_size: int = 4
+        accelerator: str = "cpu"
+
+    return DummyHParams
+
+
+@pytest.fixture
+def dummy_hparams(dummy_hparams_cls):
+    return dummy_hparams_cls()
+
+
+@pytest.fixture
+def dummy_model_cls(dummy_network, dummy_dataset, dummy_hparams_cls):
+    class DummyModel(Trainable):
+        hparams: dummy_hparams_cls
+
+        def __init__(self, hparams):
+            super().__init__(hparams, train_data=dummy_dataset, val_data=dummy_dataset, test_data=dummy_dataset)
+            self.network = dummy_network
+
+        def compute_metrics(self, batch, batch_idx) -> dict:
+            return dict(
+                loss=torch.tensor(0.0, requires_grad=True)
+            )
+
+    return DummyModel
+
+
+@pytest.fixture
+def dummy_model(dummy_model_cls, dummy_hparams):
+    return dummy_model_cls(dummy_hparams)
+
+
+def test_fit(dummy_model):
+    train_metrics = dummy_model.fit()
+
+    assert isinstance(train_metrics, dict)
+    assert "training/loss" in train_metrics
+    assert "validation/loss" in train_metrics
+
+
+def test_fit_fast(dummy_model):
+    loss = dummy_model.fit_fast(device="cpu")
+
+    assert torch.isclose(loss, torch.tensor(0.0))
+
+
+def test_hparams_copy(dummy_hparams, dummy_hparams_cls):
+    assert isinstance(dummy_hparams, HParams)
+    assert isinstance(dummy_hparams, TrainableHParams)
+    assert isinstance(dummy_hparams, dummy_hparams_cls)
+
+    hparams_copy = dummy_hparams.copy()
+
+    assert isinstance(hparams_copy, HParams)
+    assert isinstance(dummy_hparams, TrainableHParams)
+    assert isinstance(dummy_hparams, dummy_hparams_cls)
+
+
+def test_hparams_invariant(dummy_model_cls, dummy_hparams):
+    """ Ensure HParams are left unchanged after instantiation and training """
+    hparams = dummy_hparams.copy()
+
+    dummy_model1 = dummy_model_cls(hparams)
+
+    assert hparams == dummy_hparams
+
+    dummy_model1.fit()
+
+    assert hparams == dummy_hparams
+
+
+def test_checkpoint(dummy_model):
+    # TODO: temp directory
+
+    dummy_model.fit()
+
+    checkpoint = find_checkpoint()
+
+    assert Path(checkpoint).is_file()
+
+
+def test_nested_checkpoint(dummy_model_cls, dummy_hparams_cls):
+
+    class MyHParams(dummy_hparams_cls):
+        pass
+
+    class MyModel(dummy_model_cls):
+        hparams: MyHParams
+
+        def __init__(self, hparams):
+            super().__init__(hparams)
+
+    hparams = MyHParams()
+    model = MyModel(hparams)
+
+    assert model._hparams_name == "hparams"
+
+
+def test_continue_training(dummy_model):
+    print("Starting Training.")
+    dummy_model.fit()
+
+    print("Finished Training. Loading Checkpoint.")
+    checkpoint = find_checkpoint()
+
+    trained_model = dummy_model.__class__.load_from_checkpoint(checkpoint)
+
+    print("Continuing Training.")
+    trained_model.fit(
+        trainer_kwargs=dict(max_epochs=2 * dummy_model.hparams.max_epochs),
+        fit_kwargs=dict(ckpt_path=checkpoint)
+    )
+
+    print("Finished Continued Training.")
+
+    # TODO: add check that the model was actually trained for 2x epochs
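test_continue_training above leaves its final assertion as a TODO. One possible shape for that check, assuming Lightning's standard current_epoch counter on the module is meaningful after fit() returns; its exact off-by-one behaviour differs between Lightning versions, so this is a sketch rather than part of the patch.

def assert_epochs_trained(model, expected_epochs: int) -> None:
    # current_epoch is Lightning's own counter on the module; depending on the
    # Lightning version it may read expected_epochs or expected_epochs - 1 once
    # training has finished, hence the loose lower bound.
    assert model.current_epoch >= expected_epochs - 1, (
        f"expected roughly {expected_epochs} epochs, got {model.current_epoch}"
    )

# inside test_continue_training, after the second fit() call:
# assert_epochs_trained(trained_model, 2 * dummy_model.hparams.max_epochs)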
diff --git a/tests/test_trainable.py b/tests/test_trainable.py
deleted file mode 100644
index b5e8ab3..0000000
--- a/tests/test_trainable.py
+++ /dev/null
@@ -1,70 +0,0 @@
-import torch
-from torch.utils.data import TensorDataset, Dataset
-
-from lightning_trainable import Trainable, TrainableHParams
-
-
-def test_instantiate():
-    hparams = TrainableHParams(max_epochs=10, batch_size=32)
-    Trainable(hparams)
-
-
-def test_simple_model():
-    class SimpleTrainable(Trainable):
-        def __init__(self, hparams: TrainableHParams | dict,
-                     train_data: Dataset = None,
-                     val_data: Dataset = None,
-                     test_data: Dataset = None
-                     ):
-            super().__init__(hparams, train_data, val_data, test_data)
-            self.param = torch.nn.Parameter(torch.randn(8, 1))
-
-        def compute_metrics(self, batch, batch_idx) -> dict:
-            return {
-                "loss": ((batch[0] @ self.param) ** 2).mean()
-            }
-
-    train_data = TensorDataset(torch.randn(128, 8))
-
-    hparams = TrainableHParams(
-        accelerator="cpu",
-        max_epochs=10,
-        batch_size=32,
-        lr_scheduler="onecyclelr"
-    )
-    model = SimpleTrainable(hparams, train_data=train_data)
-    model.fit()
-
-
-def test_double_train():
-    class SimpleTrainable(Trainable):
-        def __init__(self, hparams: TrainableHParams | dict,
-                     train_data: Dataset = None,
-                     val_data: Dataset = None,
-                     test_data: Dataset = None
-                     ):
-            super().__init__(hparams, train_data, val_data, test_data)
-            self.param = torch.nn.Parameter(torch.randn(8, 1))
-
-        def compute_metrics(self, batch, batch_idx) -> dict:
-            return {
-                "loss": ((batch[0] @ self.param) ** 2).mean()
-            }
-
-    hparams = TrainableHParams(
-        accelerator="cpu",
-        max_epochs=1,
-        batch_size=8,
-        optimizer=dict(
-            name="adam",
-            lr=1e-3,
-        )
-    )
-
-    train_data = TensorDataset(torch.randn(128, 8))
-
-    t1 = SimpleTrainable(hparams, train_data=train_data, val_data=train_data)
-    t1.fit()
-
-    t2 = SimpleTrainable(hparams, train_data=train_data, val_data=train_data)
-    t2.fit()
diff --git a/tests/test_unet.py b/tests/test_unet.py
deleted file mode 100644
index 069b707..0000000
--- a/tests/test_unet.py
+++ /dev/null
@@ -1,60 +0,0 @@
-
-import pytest
-
-import torch
-
-from lightning_trainable.modules import UNet, UNetHParams
-
-
-@pytest.mark.parametrize("skip_mode", ["add", "concat", "none"])
-@pytest.mark.parametrize("width", [32, 48, 64])
-@pytest.mark.parametrize("height", [32, 48, 64])
-@pytest.mark.parametrize("channels", [1, 3, 5])
-def test_basic(skip_mode, channels, width, height):
-    hparams = UNetHParams(
-        input_shape=(channels, height, width),
-        output_shape=(channels, height, width),
-        down_blocks=[
-            dict(channels=[16, 32], kernel_sizes=[3, 3, 3]),
-            dict(channels=[32, 64], kernel_sizes=[3, 3, 3]),
-        ],
-        up_blocks=[
-            dict(channels=[64, 32], kernel_sizes=[3, 3, 3]),
-            dict(channels=[32, 16], kernel_sizes=[3, 3, 3]),
-        ],
-        bottom_widths=[32, 32],
-        skip_mode=skip_mode,
-        activation="relu",
-    )
-    unet = UNet(hparams)
-
-    x = torch.randn(1, *hparams.input_shape)
-    y = unet(x)
-    assert y.shape == (1, *hparams.output_shape)
-
-
-@pytest.mark.parametrize("skip_mode", ["concat", "none"])
-@pytest.mark.parametrize("width", [32, 48, 64])
-@pytest.mark.parametrize("height", [32, 48, 64])
-@pytest.mark.parametrize("channels", [1, 3, 5])
-def test_inconsistent_channels(skip_mode, channels, width, height):
-    hparams = UNetHParams(
-        input_shape=(channels, height, width),
-        output_shape=(channels, height, width),
-        down_blocks=[
-            dict(channels=[16, 24], kernel_sizes=[3, 3, 3]),
-            dict(channels=[48, 17], kernel_sizes=[3, 3, 3]),
-        ],
-        up_blocks=[
-            dict(channels=[57, 31], kernel_sizes=[3, 3, 3]),
-            dict(channels=[12, 87], kernel_sizes=[3, 3, 3]),
-        ],
-        bottom_widths=[32, 32],
-        skip_mode=skip_mode,
-        activation="relu",
-    )
-    unet = UNet(hparams)
-
-    x = torch.randn(1, *hparams.input_shape)
-    y = unet(x)
-    assert y.shape == (1, *hparams.output_shape)