diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index 36f78e4..14fcbe2 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -9,8 +9,8 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
- os: [ubuntu-latest, windows-latest]
- python-version: ["3.10"]
+ os: [ubuntu-latest, windows-latest, macos-latest]
+ python-version: ["3.10", "3.11"]
steps:
- uses: actions/checkout@v3
diff --git a/pyproject.toml b/pyproject.toml
index 2004e9e..0cef9e1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,7 +9,7 @@ description = "A light-weight lightning_trainable module for pytorch-lightning."
readme = "README.md"
requires-python = ">=3.10"
license = { file = "LICENSE" }
-keywords = ["Machine Learning", "PyTorch", "PyTorch-Lightning"]
+keywords = ["Machine-Learning", "PyTorch", "PyTorch-Lightning"]
authors = [
{ name = "Lars Kühmichel", email = "lars.kuehmichel@stud.uni-heidelberg.de" }
@@ -57,7 +57,9 @@ tests = [
]
experiments = [
- # these are just recommended packages to run experiments
+ # required
+ "ray[tune] ~= 2.4",
+ # recommended
"numpy ~= 1.24",
"matplotlib ~= 3.7",
"jupyterlab ~= 3.6",
diff --git a/src/lightning_trainable/callbacks/epoch_progress_bar.py b/src/lightning_trainable/callbacks/epoch_progress_bar.py
index 9dd973f..a375c97 100644
--- a/src/lightning_trainable/callbacks/epoch_progress_bar.py
+++ b/src/lightning_trainable/callbacks/epoch_progress_bar.py
@@ -2,7 +2,11 @@
from lightning.pytorch.callbacks import ProgressBar
from lightning.pytorch.callbacks.progress.tqdm_progress import Tqdm
+from lightning_trainable.utils import deprecate
+
+@deprecate("EpochProgressBar causes issues when continuing training or using multi-GPU. "
+ "Use the default Lightning ProgressBar instead.")
class EpochProgressBar(ProgressBar):
def __init__(self):
super().__init__()
@@ -25,6 +29,8 @@ def on_train_end(self, trainer, pl_module):
self.bar.close()
+@deprecate("StepProgressBar causes issues when continuing training or using multi-GPU. "
+ "Use the default Lightning ProgressBar instead.")
class StepProgressBar(ProgressBar):
def __init__(self):
super().__init__()
diff --git a/src/lightning_trainable/datasets/core/__init__.py b/src/lightning_trainable/datasets/core/__init__.py
index a0f49bf..6e526bd 100644
--- a/src/lightning_trainable/datasets/core/__init__.py
+++ b/src/lightning_trainable/datasets/core/__init__.py
@@ -1 +1,2 @@
+from .distribution_dataset import DistributionDataset
from .joint import JointDataset, JointIterableDataset
diff --git a/src/lightning_trainable/hparams/attribute_dict.py b/src/lightning_trainable/hparams/attribute_dict.py
index dc484a1..98a4534 100644
--- a/src/lightning_trainable/hparams/attribute_dict.py
+++ b/src/lightning_trainable/hparams/attribute_dict.py
@@ -14,3 +14,8 @@ def __getattribute__(self, item):
def __setattr__(self, key, value):
self[key] = value
+
+ def copy(self):
+ # copies of AttributeDicts should be AttributeDicts
+ # see also https://github.com/LarsKue/lightning-trainable/issues/13
+ return self.__class__(**super().copy())
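
For illustration, a minimal sketch of the behaviour the new copy override guarantees (AttributeDict is importable from lightning_trainable.hparams, as used elsewhere in this patch; the field names below are arbitrary):

```python
from lightning_trainable.hparams import AttributeDict

hparams = AttributeDict(batch_size=32, learning_rate=1e-3)
hparams_copy = hparams.copy()

# without the override, dict.copy() would return a plain dict
assert isinstance(hparams_copy, AttributeDict)
assert hparams_copy.batch_size == 32
```
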
diff --git a/src/lightning_trainable/hparams/types/choice.py b/src/lightning_trainable/hparams/types/choice.py
index 4c0c545..00315ca 100644
--- a/src/lightning_trainable/hparams/types/choice.py
+++ b/src/lightning_trainable/hparams/types/choice.py
@@ -11,6 +11,9 @@ def __call__(cls, *choices):
namespace = {"choices": choices}
return type(name, bases, namespace)
+ def __repr__(cls):
+ return f"Choice{cls.choices!r}"
+
class Choice(metaclass=ChoiceMeta):
"""
diff --git a/src/lightning_trainable/hparams/types/range.py b/src/lightning_trainable/hparams/types/range.py
index b982b93..a20b739 100644
--- a/src/lightning_trainable/hparams/types/range.py
+++ b/src/lightning_trainable/hparams/types/range.py
@@ -21,6 +21,9 @@ def __call__(cls, lower: float | int, upper: float | int, exclude: str | None =
namespace = {"lower": lower, "upper": upper, "exclude": exclude}
return type(name, bases, namespace)
+    def __repr__(cls):
+        return f"Range({cls.lower!r}, {cls.upper!r}, exclude={cls.exclude!r})"
+
class Range(metaclass=RangeMeta):
"""
diff --git a/src/lightning_trainable/metrics/__init__.py b/src/lightning_trainable/metrics/__init__.py
index 06a1dab..30bcacd 100644
--- a/src/lightning_trainable/metrics/__init__.py
+++ b/src/lightning_trainable/metrics/__init__.py
@@ -1 +1,4 @@
from .accuracy import accuracy
+from .error import error
+from .sinkhorn import sinkhorn_auto as sinkhorn
+from .wasserstein import wasserstein
diff --git a/src/lightning_trainable/metrics/error.py b/src/lightning_trainable/metrics/error.py
new file mode 100644
index 0000000..ec4b458
--- /dev/null
+++ b/src/lightning_trainable/metrics/error.py
@@ -0,0 +1,6 @@
+
+from .accuracy import accuracy
+
+
+def error(logits, targets, *, k=1):
+ return 1.0 - accuracy(logits, targets, k=k)
diff --git a/src/lightning_trainable/metrics/sinkhorn.py b/src/lightning_trainable/metrics/sinkhorn.py
new file mode 100644
index 0000000..623ecdc
--- /dev/null
+++ b/src/lightning_trainable/metrics/sinkhorn.py
@@ -0,0 +1,107 @@
+import warnings
+
+import torch
+from torch import Tensor
+
+import torch.nn.functional as F
+import numpy as np
+
+
+def sinkhorn(a: Tensor, b: Tensor, cost: Tensor, epsilon: float, steps: int = 1000) -> Tensor:
+ """
+ Computes the Sinkhorn optimal transport plan from sample weights of two distributions.
+ This version does not use log-space computations, but allows for zero or negative weights.
+
+ @param a: Sample weights from the first distribution in shape (n,)
+ @param b: Sample weights from the second distribution in shape (m,)
+ @param cost: Cost matrix in shape (n, m).
+ @param epsilon: Entropic regularization parameter.
+ @param steps: Number of iterations.
+ """
+ if cost.shape != (len(a), len(b)):
+ raise ValueError(f"Expected cost to have shape {(len(a), len(b))}, but got {cost.shape}.")
+
+ gain = torch.exp(-cost / epsilon)
+
+ if gain.mean() < 1e-30:
+ warnings.warn(f"Detected low bandwidth ({epsilon:.1e}) relative to cost ({cost.mean().item():.1e}). "
+ f"You may experience numerical instabilities. Consider increasing epsilon or using sinkhorn_log.")
+
+ # Initialize the dual variables.
+ u = torch.ones(len(a), dtype=a.dtype, device=a.device)
+ v = torch.ones(len(b), dtype=b.dtype, device=b.device)
+
+ # Compute the Sinkhorn iterations.
+ for _ in range(steps):
+ v = b / (torch.matmul(gain.T, u) + 1e-50)
+ u = a / (torch.matmul(gain, v) + 1e-50)
+
+ # Return the transport plan.
+ return u[:, None] * gain * v[None, :]
+
+
+def sinkhorn_log(log_a: Tensor, log_b: Tensor, cost: Tensor, epsilon: float, steps: int = 1000) -> Tensor:
+ """
+ Computes the Sinkhorn optimal transport plan from sample weights of two distributions.
+ This version uses log-space computations to avoid numerical instabilities, but disallows zero or negative weights.
+
+ @param log_a: Log sample weights from the first distribution in shape (n,)
+ @param log_b: Log sample weights from the second distribution in shape (m,)
+ @param cost: Cost matrix in shape (n, m).
+ @param epsilon: Entropic regularization parameter.
+ @param steps: Number of iterations.
+ """
+ if cost.shape != (len(log_a), len(log_b)):
+ raise ValueError(f"Expected cost to have shape {(len(log_a), len(log_b))}, but got {cost.shape}.")
+
+ log_gain = -cost / epsilon
+
+ # Initialize the dual variables.
+ log_u = torch.zeros(len(log_a), dtype=log_a.dtype, device=log_a.device)
+ log_v = torch.zeros(len(log_b), dtype=log_b.dtype, device=log_b.device)
+
+ # Compute the Sinkhorn iterations.
+ for _ in range(steps):
+ log_v = log_b - torch.logsumexp(log_gain + log_u[:, None], dim=0)
+ log_u = log_a - torch.logsumexp(log_gain + log_v[None, :], dim=1)
+
+ plan = torch.exp(log_u[:, None] + log_gain + log_v[None, :])
+
+    # check that the marginals of the transport plan match the target distributions
+    if not torch.allclose(plan.sum(dim=0), torch.exp(log_b)) or not torch.allclose(plan.sum(dim=1), torch.exp(log_a)):
+        warnings.warn("Sinkhorn did not converge. Consider increasing epsilon or the number of iterations.")
+
+ # Return the transport plan.
+ return plan
+
+
+def sinkhorn_auto(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: float = 1.0, steps: int = 1000) -> Tensor:
+ """
+ Computes the Sinkhorn optimal transport plan from samples from two distributions.
+ See also: sinkhorn_log
+
+ @param x: Samples from the first distribution in shape (n, ...).
+ @param y: Samples from the second distribution in shape (m, ...).
+ @param cost: Optional cost matrix in shape (n, m).
+ If not provided, the Euclidean distance is used.
+    @param epsilon: Optional entropic regularization parameter.
+        This parameter is scaled by half the mean of the cost matrix, which keeps its effect roughly
+        independent of the data dimensionality. Note that this scaling is exclusive to this method;
+        sinkhorn_log expects the raw entropic regularization value.
+ @param steps: Number of iterations.
+ """
+ if x.shape[1:] != y.shape[1:]:
+ raise ValueError(f"Expected x and y to live in the same feature space, "
+ f"but got {x.shape[1:]} and {y.shape[1:]}.")
+ if cost is None:
+ cost = x[:, None] - y[None, :]
+ cost = torch.flatten(cost, start_dim=2)
+ cost = torch.linalg.norm(cost, dim=-1)
+
+    # scale epsilon by half the mean cost to keep its effect roughly independent of the data dimension
+ epsilon = epsilon * cost.mean() / 2
+
+ # Initialize the sample weights.
+ log_a = torch.zeros(len(x), device=x.device) - np.log(len(x))
+ log_b = torch.zeros(len(y), device=y.device) - np.log(len(y))
+
+ return sinkhorn_log(log_a, log_b, cost, epsilon, steps)
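
For illustration, a hedged sketch of calling sinkhorn_log directly with uniform weights, mirroring what sinkhorn_auto does internally (sample sizes and epsilon are arbitrary):

```python
import math

import torch

from lightning_trainable.metrics.sinkhorn import sinkhorn_log

n, m = 64, 80
x = torch.randn(n, 2)
y = torch.randn(m, 2) + 1.0

# pairwise Euclidean cost matrix of shape (n, m)
cost = torch.cdist(x, y)

# uniform log-weights, as sinkhorn_auto uses internally
log_a = torch.zeros(n) - math.log(n)
log_b = torch.zeros(m) - math.log(m)

plan = sinkhorn_log(log_a, log_b, cost, epsilon=0.5, steps=200)

# the transport plan has shape (n, m) and its marginals approximate a and b
assert plan.shape == (n, m)
```
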
diff --git a/src/lightning_trainable/metrics/wasserstein.py b/src/lightning_trainable/metrics/wasserstein.py
new file mode 100644
index 0000000..1fdb447
--- /dev/null
+++ b/src/lightning_trainable/metrics/wasserstein.py
@@ -0,0 +1,18 @@
+
+import torch
+from torch import Tensor
+
+from .sinkhorn import sinkhorn_auto
+
+
+def wasserstein(x: Tensor, y: Tensor, cost: Tensor = None, epsilon: float = 0.1, steps: int = 10) -> Tensor:
+ """
+    Computes the entropy-regularized Wasserstein distance between two empirical distributions,
+    i.e. the total cost of the entropic optimal transport plan.
+    See also: sinkhorn_auto
+ """
+ if cost is None:
+ cost = x[:, None] - y[None, :]
+ cost = torch.flatten(cost, start_dim=2)
+ cost = torch.linalg.norm(cost, dim=-1)
+
+ return torch.sum(sinkhorn_auto(x, y, cost, epsilon, steps) * cost)
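
And a hedged end-to-end sketch of the new wasserstein metric on two sets of samples (dimensions and the shift are arbitrary; the result is the entropy-regularized transport cost):

```python
import torch

from lightning_trainable.metrics import wasserstein

x = torch.randn(256, 8)        # samples from the first distribution
y = torch.randn(256, 8) + 0.5  # samples from a shifted distribution

distance = wasserstein(x, y, epsilon=0.1, steps=100)
print(float(distance))
```
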
diff --git a/src/lightning_trainable/modules/__init__.py b/src/lightning_trainable/modules/__init__.py
index b5cb254..a710cc6 100644
--- a/src/lightning_trainable/modules/__init__.py
+++ b/src/lightning_trainable/modules/__init__.py
@@ -1,4 +1,3 @@
-from .convolutional import ConvolutionalNetwork, ConvolutionalNetworkHParams
from .fully_connected import FullyConnectedNetwork, FullyConnectedNetworkHParams
from .hparams_module import HParamsModule
-from .unet import UNet, UNetHParams, UNetBlockHParams
+from .simple_unet import SimpleUNet, SimpleUNetHParams
diff --git a/src/lightning_trainable/modules/convolutional/__init__.py b/src/lightning_trainable/modules/convolutional/__init__.py
deleted file mode 100644
index 7111d1e..0000000
--- a/src/lightning_trainable/modules/convolutional/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from .block import ConvolutionalBlock
-from .block_hparams import ConvolutionalBlockHParams
-
-from .network import ConvolutionalNetwork
-from .hparams import ConvolutionalNetworkHParams
diff --git a/src/lightning_trainable/modules/convolutional/block.py b/src/lightning_trainable/modules/convolutional/block.py
deleted file mode 100644
index e83f743..0000000
--- a/src/lightning_trainable/modules/convolutional/block.py
+++ /dev/null
@@ -1,74 +0,0 @@
-import torch
-import torch.nn as nn
-
-import lightning_trainable.utils as utils
-
-from ..hparams_module import HParamsModule
-
-from .block_hparams import ConvolutionalBlockHParams
-
-
-class ConvolutionalBlock(HParamsModule):
- """
- Implements a series of convolutions, each followed by an activation function.
- """
-
- hparams: ConvolutionalBlockHParams
-
- def __init__(self, hparams: dict | ConvolutionalBlockHParams):
- super().__init__(hparams)
-
- self.network = self.configure_network()
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- return self.network(x)
-
- def configure_network(self):
- # construct convolutions
- convolutions = []
- cck = zip(self.hparams.channels[:-1], self.hparams.channels[1:], self.hparams.kernel_sizes)
- for in_channels, out_channels, kernel_size in cck:
- conv = nn.Conv2d(
- in_channels=in_channels,
- out_channels=out_channels,
- kernel_size=kernel_size,
- padding=self.hparams.padding,
- dilation=self.hparams.dilation,
- groups=self.hparams.groups,
- bias=self.hparams.bias,
- padding_mode=self.hparams.padding_mode,
- )
-
- convolutions.append(conv)
-
- # construct activations
- activations = []
- for _ in range(len(self.hparams.channels) - 2):
- activations.append(utils.get_activation(self.hparams.activation)(inplace=True))
-
- layers = list(utils.zip(convolutions, activations, exhaustive=True, nested=False))
-
- # add pooling layer if requested
- if self.hparams.pool:
- match self.hparams.pool_direction:
- case "up":
- if self.hparams.pool_position == "first":
- channels = self.hparams.channels[0]
- else:
- channels = self.hparams.channels[-1]
-
- pool = nn.ConvTranspose2d(channels, channels, 2, 2)
- case "down":
- pool = nn.MaxPool2d(2, 2)
- case _:
- raise NotImplementedError(f"Unrecognized pool direction '{self.hparams.pool_direction}'.")
-
- match self.hparams.pool_position:
- case "first":
- layers.insert(0, pool)
- case "last":
- layers.append(pool)
- case _:
- raise NotImplementedError(f"Unrecognized pool position '{self.hparams.pool_position}'.")
-
- return nn.Sequential(*layers)
diff --git a/src/lightning_trainable/modules/convolutional/block_hparams.py b/src/lightning_trainable/modules/convolutional/block_hparams.py
deleted file mode 100644
index b901ea9..0000000
--- a/src/lightning_trainable/modules/convolutional/block_hparams.py
+++ /dev/null
@@ -1,25 +0,0 @@
-from lightning_trainable.hparams import HParams, Choice
-
-
-class ConvolutionalBlockHParams(HParams):
- channels: list[int]
- kernel_sizes: list[int]
- activation: str = "relu"
- padding: str | int = 0
- dilation: int = 1
- groups: int = 1
- bias: bool = True
- padding_mode: str = "zeros"
-
- pool: bool = False
- pool_direction: Choice("up", "down") = "down"
- pool_position: Choice("first", "last") = "last"
-
- @classmethod
- def validate_parameters(cls, hparams):
- hparams = super().validate_parameters(hparams)
-
- if not len(hparams.channels) == len(hparams.kernel_sizes):
- raise ValueError(f"{cls.__name__} needs same number of channels and kernel sizes.")
-
- return hparams
diff --git a/src/lightning_trainable/modules/convolutional/hparams.py b/src/lightning_trainable/modules/convolutional/hparams.py
deleted file mode 100644
index 046e54e..0000000
--- a/src/lightning_trainable/modules/convolutional/hparams.py
+++ /dev/null
@@ -1,8 +0,0 @@
-
-from lightning_trainable.hparams import HParams
-
-from .block_hparams import ConvolutionalBlockHParams
-
-
-class ConvolutionalNetworkHParams(HParams):
- block_hparams: list[ConvolutionalBlockHParams]
diff --git a/src/lightning_trainable/modules/convolutional/network.py b/src/lightning_trainable/modules/convolutional/network.py
deleted file mode 100644
index c4ad3ad..0000000
--- a/src/lightning_trainable/modules/convolutional/network.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import torch
-import torch.nn as nn
-
-
-from ..hparams_module import HParamsModule
-from ..sequential_mixin import SequentialMixin
-
-from .block import ConvolutionalBlock
-from .hparams import ConvolutionalNetworkHParams
-
-
-class ConvolutionalNetwork(SequentialMixin, HParamsModule):
- """
- Implements a series of pooled convolutional blocks.
- """
- hparams: ConvolutionalNetworkHParams
-
- def __init__(self, hparams):
- super().__init__(hparams)
- self.network = self.configure_network()
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- return self.network(x)
-
- def configure_network(self):
- blocks = [ConvolutionalBlock(hparams) for hparams in self.hparams.block_hparams]
-
- return nn.Sequential(*blocks)
diff --git a/src/lightning_trainable/modules/fully_connected/hparams.py b/src/lightning_trainable/modules/fully_connected/hparams.py
index 2d3dccb..7a981d1 100644
--- a/src/lightning_trainable/modules/fully_connected/hparams.py
+++ b/src/lightning_trainable/modules/fully_connected/hparams.py
@@ -8,3 +8,6 @@ class FullyConnectedNetworkHParams(HParams):
layer_widths: list[int]
activation: str = "relu"
+
+ norm: Choice("none", "batch", "layer") = "none"
+ dropout: float = 0.0
diff --git a/src/lightning_trainable/modules/fully_connected/network.py b/src/lightning_trainable/modules/fully_connected/network.py
index d1b0583..b4a133e 100644
--- a/src/lightning_trainable/modules/fully_connected/network.py
+++ b/src/lightning_trainable/modules/fully_connected/network.py
@@ -1,4 +1,6 @@
-import torch
+
+from torch import Tensor
+
import torch.nn as nn
from lightning_trainable.utils import get_activation
@@ -19,22 +21,63 @@ def __init__(self, hparams):
self.network = self.configure_network()
- def forward(self, x: torch.Tensor) -> torch.Tensor:
+ def forward(self, x: Tensor) -> Tensor:
return self.network(x)
def configure_network(self):
- layers = []
- widths = [self.hparams.input_dims, *self.hparams.layer_widths, self.hparams.output_dims]
+ widths = self.hparams.layer_widths
+
+ # input layer
+ input_linear = self.configure_linear(self.hparams.input_dims, widths[0])
+ input_activation = self.configure_activation()
+ layers = [input_linear, input_activation]
+
+ # hidden layers
+ for (in_features, out_features) in zip(widths[:-1], widths[1:]):
+ dropout = self.configure_dropout()
+ norm = self.configure_norm(in_features)
+
+ activation = self.configure_activation()
+ linear = self.configure_linear(in_features, out_features)
- if self.hparams.input_dims == "lazy":
- widths = widths[1:]
- layers = [nn.LazyLinear(widths[0]), get_activation(self.hparams.activation)(inplace=True)]
+ if dropout is not None:
+ layers.append(dropout)
- for (w1, w2) in zip(widths[:-1], widths[1:]):
- layers.append(nn.Linear(w1, w2))
- layers.append(get_activation(self.hparams.activation)(inplace=True))
+ if norm is not None:
+ layers.append(norm)
- # remove last activation
- layers = layers[:-1]
+ layers.append(linear)
+ layers.append(activation)
+
+ # output layer
+ output_linear = self.configure_linear(widths[-1], self.hparams.output_dims)
+ layers.append(output_linear)
return nn.Sequential(*layers)
+
+ def configure_linear(self, in_features, out_features) -> nn.Module:
+ match in_features:
+ case "lazy":
+ return nn.LazyLinear(out_features)
+ case int() as in_features:
+ return nn.Linear(in_features, out_features)
+ case other:
+ raise NotImplementedError(f"Unrecognized input_dims value: '{other}'")
+
+ def configure_activation(self) -> nn.Module:
+ return get_activation(self.hparams.activation)(inplace=True)
+
+ def configure_dropout(self) -> nn.Module | None:
+ if self.hparams.dropout > 0:
+ return nn.Dropout(self.hparams.dropout)
+
+ def configure_norm(self, num_features) -> nn.Module | None:
+ match self.hparams.norm:
+ case "none":
+ return None
+ case "batch":
+ return nn.BatchNorm1d(num_features)
+ case "layer":
+ return nn.LayerNorm(num_features)
+ case other:
+ raise NotImplementedError(f"Unrecognized norm value: '{other}'")
diff --git a/src/lightning_trainable/modules/hparams_module.py b/src/lightning_trainable/modules/hparams_module.py
index 225ce17..d8a0691 100644
--- a/src/lightning_trainable/modules/hparams_module.py
+++ b/src/lightning_trainable/modules/hparams_module.py
@@ -17,4 +17,7 @@ def __init__(self, hparams: HParams | dict):
self.hparams = hparams
def __init_subclass__(cls, **kwargs):
- cls.hparams_type = cls.__annotations__["hparams"]
+ hparams_type = cls.__annotations__.get("hparams")
+ if hparams_type is not None:
+ # only overwrite hparams_type if it is defined by the child class
+ cls.hparams_type = hparams_type
diff --git a/src/lightning_trainable/modules/simple_unet/__init__.py b/src/lightning_trainable/modules/simple_unet/__init__.py
new file mode 100644
index 0000000..64c06ad
--- /dev/null
+++ b/src/lightning_trainable/modules/simple_unet/__init__.py
@@ -0,0 +1,3 @@
+
+from .hparams import SimpleUNetHParams
+from .network import SimpleUNet
diff --git a/src/lightning_trainable/modules/simple_unet/down_block.py b/src/lightning_trainable/modules/simple_unet/down_block.py
new file mode 100644
index 0000000..3a3cd4f
--- /dev/null
+++ b/src/lightning_trainable/modules/simple_unet/down_block.py
@@ -0,0 +1,29 @@
+
+import torch
+import torch.nn as nn
+from torch.nn import Module
+
+from lightning_trainable.utils import get_activation
+
+
+class SimpleUNetDownBlock(Module):
+ def __init__(self, in_channels: int, out_channels: int, kernel_size: int, block_size: int = 2, activation: str = "relu"):
+ super().__init__()
+
+ self.channels = torch.linspace(in_channels, out_channels, block_size + 1, dtype=torch.int64).tolist()
+
+ layers = []
+ for c1, c2 in zip(self.channels[:-2], self.channels[1:-1]):
+ layers.append(nn.Conv2d(c1, c2, kernel_size, padding="same"))
+ layers.append(get_activation(activation)(inplace=True))
+
+ layers.append(nn.Conv2d(self.channels[-2], self.channels[-1], kernel_size, padding="same"))
+ layers.append(nn.MaxPool2d(2))
+
+ self.block = nn.Sequential(*layers)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.block(x)
+
+ def extra_repr(self) -> str:
+ return f"in_channels={self.channels[0]}, out_channels={self.channels[-1]}, kernel_size={self.block[0].kernel_size[0]}, block_size={len(self.channels) - 1}"
diff --git a/src/lightning_trainable/modules/simple_unet/hparams.py b/src/lightning_trainable/modules/simple_unet/hparams.py
new file mode 100644
index 0000000..bb4714d
--- /dev/null
+++ b/src/lightning_trainable/modules/simple_unet/hparams.py
@@ -0,0 +1,23 @@
+from lightning_trainable.hparams import HParams
+
+
+class SimpleUNetHParams(HParams):
+ input_shape: tuple[int, int, int]
+ conditions: int = 0
+
+ channels: list[int]
+ kernel_sizes: list[int]
+ fc_widths: list[int]
+ activation: str = "ReLU"
+
+ block_size: int = 2
+
+ @classmethod
+ def validate_parameters(cls, hparams):
+ hparams = super().validate_parameters(hparams)
+
+ if len(hparams.channels) + 2 != len(hparams.kernel_sizes):
+ raise ValueError(f"Number of channels ({len(hparams.channels)}) + 2 must be equal "
+ f"to the number of kernel sizes ({len(hparams.kernel_sizes)})")
+
+ return hparams
diff --git a/src/lightning_trainable/modules/simple_unet/network.py b/src/lightning_trainable/modules/simple_unet/network.py
new file mode 100644
index 0000000..f2fce9a
--- /dev/null
+++ b/src/lightning_trainable/modules/simple_unet/network.py
@@ -0,0 +1,73 @@
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+from ..fully_connected import FullyConnectedNetwork
+from ..hparams_module import HParamsModule
+
+from .down_block import SimpleUNetDownBlock
+from .hparams import SimpleUNetHParams
+from .up_block import SimpleUNetUpBlock
+
+
+class SimpleUNet(HParamsModule):
+ hparams: SimpleUNetHParams
+
+ def __init__(self, hparams: dict | SimpleUNetHParams):
+ super().__init__(hparams)
+
+ channels = [self.hparams.input_shape[0], *self.hparams.channels]
+
+ self.down_blocks = nn.ModuleList([
+ SimpleUNetDownBlock(c1, c2, kernel_size, self.hparams.block_size, self.hparams.activation)
+ for c1, c2, kernel_size in zip(channels[:-1], channels[1:], self.hparams.kernel_sizes)
+ ])
+
+ fc_channels = self.hparams.channels[-1]
+ height = self.hparams.input_shape[1] // 2 ** (len(channels) - 1)
+ width = self.hparams.input_shape[2] // 2 ** (len(channels) - 1)
+
+ fc_hparams = dict(
+ input_dims=self.hparams.conditions + fc_channels * height * width,
+ output_dims=fc_channels * height * width,
+ activation=self.hparams.activation,
+ layer_widths=self.hparams.fc_widths,
+ )
+ self.fc = FullyConnectedNetwork(fc_hparams)
+ self.up_blocks = nn.ModuleList([
+ SimpleUNetUpBlock(c2, c1, kernel_size, self.hparams.block_size, self.hparams.activation)
+ for c1, c2, kernel_size in zip(channels[:-1], channels[1:], self.hparams.kernel_sizes)
+ ][::-1])
+
+ def forward(self, image: Tensor, condition: Tensor = None) -> Tensor:
+ residuals = []
+ for block in self.down_blocks:
+ image = block(image)
+ residuals.append(image)
+
+ shape = image.shape
+ image = image.flatten(start_dim=1)
+
+ if condition is not None:
+ image = torch.cat([image, condition], dim=1)
+
+ image = self.fc(image)
+
+ image = image.reshape(shape)
+
+ for block in self.up_blocks:
+ residual = residuals.pop()
+ image = block(image + residual)
+
+ return image
+
+ def down(self, image: Tensor) -> Tensor:
+ for block in self.down_blocks:
+ image = block(image)
+ return image
+
+ def up(self, image: Tensor) -> Tensor:
+ for block in self.up_blocks:
+ image = block(image)
+ return image
diff --git a/src/lightning_trainable/modules/simple_unet/up_block.py b/src/lightning_trainable/modules/simple_unet/up_block.py
new file mode 100644
index 0000000..b389626
--- /dev/null
+++ b/src/lightning_trainable/modules/simple_unet/up_block.py
@@ -0,0 +1,29 @@
+
+import torch
+import torch.nn as nn
+from torch.nn import Module
+
+from lightning_trainable.utils import get_activation
+
+
+class SimpleUNetUpBlock(Module):
+ def __init__(self, in_channels: int, out_channels: int, kernel_size: int, block_size: int = 2, activation: str = "relu"):
+ super().__init__()
+
+ self.channels = torch.linspace(in_channels, out_channels, block_size + 1, dtype=torch.int64).tolist()
+
+ layers = []
+ for c1, c2 in zip(self.channels[:-2], self.channels[1:-1]):
+ layers.append(nn.Conv2d(c1, c2, kernel_size, padding="same"))
+ layers.append(get_activation(activation)(inplace=True))
+
+ layers.append(nn.Conv2d(self.channels[-2], self.channels[-1], kernel_size, padding="same"))
+ layers.append(nn.ConvTranspose2d(self.channels[-1], self.channels[-1], 2, 2))
+
+ self.block = nn.Sequential(*layers)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.block(x)
+
+ def extra_repr(self) -> str:
+ return f"in_channels={self.channels[0]}, out_channels={self.channels[-1]}, kernel_size={self.block[0].kernel_size[0]}, block_size={len(self.channels) - 1}"
diff --git a/src/lightning_trainable/modules/unet/__init__.py b/src/lightning_trainable/modules/unet/__init__.py
deleted file mode 100644
index 964f0cb..0000000
--- a/src/lightning_trainable/modules/unet/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from .hparams import UNetHParams, UNetBlockHParams
-from .network import UNet
diff --git a/src/lightning_trainable/modules/unet/hparams.py b/src/lightning_trainable/modules/unet/hparams.py
deleted file mode 100644
index 6afc972..0000000
--- a/src/lightning_trainable/modules/unet/hparams.py
+++ /dev/null
@@ -1,66 +0,0 @@
-from lightning_trainable.hparams import HParams, Choice
-
-
-class UNetBlockHParams(HParams):
- # this is a reduced variant of the ConvolutionalBlockHParams
- # where additional parameters are generated automatically by UNet
- channels: list[int]
- kernel_sizes: list[int]
-
- @classmethod
- def validate_parameters(cls, hparams):
- hparams = super().validate_parameters(hparams)
-
- if len(hparams.channels) + 1 != len(hparams.kernel_sizes):
- raise ValueError(f"Block needs one more kernel size than channels.")
-
- return hparams
-
-
-class UNetHParams(HParams):
- # input/output image size, not including batch dimensions
- input_shape: tuple[int, int, int]
- output_shape: tuple[int, int, int]
-
- # list of hparams for individual down/up blocks
- down_blocks: list[dict | UNetBlockHParams]
- up_blocks: list[dict | UNetBlockHParams]
-
- # hidden layer sizes for the bottom, fully connected part of the UNet
- bottom_widths: list[int]
-
- # skip connection mode
- skip_mode: Choice("add", "concat", "none") = "add"
- activation: str = "relu"
-
- @classmethod
- def validate_parameters(cls, hparams):
- hparams = super().validate_parameters(hparams)
-
- for i in range(len(hparams.down_blocks)):
- hparams.down_blocks[i] = UNetBlockHParams(**hparams.down_blocks[i])
- for i in range(len(hparams.up_blocks)):
- hparams.up_blocks[i] = UNetBlockHParams(**hparams.up_blocks[i])
-
- url = "https://github.com/LarsKue/lightning-trainable/"
- if hparams.input_shape[1:] != hparams.output_shape[1:]:
- raise ValueError(f"Different image sizes for input and output are not yet supported. "
- f"If you need this feature, please file an issue or pull request at {url}.")
-
- if hparams.input_shape[1] % 2 or hparams.input_shape[2] % 2:
- raise ValueError(f"Odd input shape is not yet supported. "
- f"If you need this feature, please file an issue or pull request at {url}.")
-
- minimum_size = 2 ** len(hparams.down_blocks)
- if hparams.input_shape[1] < minimum_size or hparams.input_shape[2] < minimum_size:
- raise ValueError(f"Input shape {hparams.input_shape[1:]} is too small for {len(hparams.down_blocks)} "
- f"down blocks. Minimum size is {(minimum_size, minimum_size)}.")
-
- if hparams.skip_mode == "add":
- # ensure matching number of channels for down output as up input
- for i, (down_block, up_block) in enumerate(zip(hparams.down_blocks, reversed(hparams.up_blocks))):
- if down_block["channels"][-1] != up_block["channels"][0]:
- raise ValueError(f"Output channels of down block {i} must match input channels of up block "
- f"{len(hparams.up_blocks) - (i + 1)} for skip mode '{hparams.skip_mode}'.")
-
- return hparams
diff --git a/src/lightning_trainable/modules/unet/network.py b/src/lightning_trainable/modules/unet/network.py
deleted file mode 100644
index f40a31a..0000000
--- a/src/lightning_trainable/modules/unet/network.py
+++ /dev/null
@@ -1,100 +0,0 @@
-import torch
-import torch.nn as nn
-
-import math
-
-from lightning_trainable.hparams import AttributeDict
-
-from ..convolutional import ConvolutionalBlock
-from ..fully_connected import FullyConnectedNetwork
-from ..hparams_module import HParamsModule
-
-from .hparams import UNetHParams
-from .skip_connection import SkipConnection
-from .temporary_flatten import TemporaryFlatten
-
-
-class UNet(HParamsModule):
- hparams: UNetHParams
-
- def __init__(self, hparams: dict | UNetHParams):
- super().__init__(hparams)
-
- self.network = self.configure_network(
- input_shape=self.hparams.input_shape,
- output_shape=self.hparams.output_shape,
- down_blocks=self.hparams.down_blocks,
- up_blocks=self.hparams.up_blocks,
- )
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- return self.network(x)
-
- def configure_levels(self, input_shape: (int, int, int), output_shape: (int, int, int), down_blocks: list[dict], up_blocks: list[dict]):
- """
- Recursively configures the levels of the UNet.
- """
- if not down_blocks:
- return self.configure_fc(input_shape, output_shape)
-
- down_block = down_blocks[0]
- up_block = up_blocks[-1]
-
- down_hparams = AttributeDict(
- channels=[input_shape[0], *down_block["channels"]],
- kernel_sizes=down_block["kernel_sizes"],
- activation=self.hparams.activation,
- padding="same",
- pool=True,
- pool_direction="down",
- pool_position="last"
- )
- up_hparams = AttributeDict(
- channels=[*up_block["channels"], output_shape[0]],
- kernel_sizes=up_block["kernel_sizes"],
- activation=self.hparams.activation,
- padding="same",
- pool=True,
- pool_direction="up",
- pool_position="first"
- )
-
- next_input_shape = (down_hparams.channels[-1], input_shape[1] // 2, input_shape[2] // 2)
- next_output_shape = (up_hparams.channels[0], output_shape[1] // 2, output_shape[2] // 2)
-
- if self.hparams.skip_mode == "concat":
- up_hparams.channels[0] += down_hparams.channels[-1]
-
- down_block = ConvolutionalBlock(down_hparams)
- up_block = ConvolutionalBlock(up_hparams)
-
- next_level = self.configure_levels(
- input_shape=next_input_shape,
- output_shape=next_output_shape,
- down_blocks=down_blocks[1:],
- up_blocks=up_blocks[:-1],
- )
-
- return nn.Sequential(
- down_block,
- SkipConnection(next_level, mode=self.hparams.skip_mode),
- up_block,
- )
-
- def configure_fc(self, input_shape: (int, int, int), output_shape: (int, int, int)):
- """
- Configures the lowest level of the UNet as a fully connected network.
- """
- hparams = dict(
- input_dims=math.prod(input_shape),
- output_dims=math.prod(output_shape),
- layer_widths=self.hparams.bottom_widths,
- activation=self.hparams.activation,
- )
- return TemporaryFlatten(FullyConnectedNetwork(hparams), input_shape, output_shape)
-
- def configure_network(self, input_shape: (int, int, int), output_shape: (int, int, int), down_blocks: list[dict], up_blocks: list[dict]):
- """
- Configures the UNet.
- """
- return self.configure_levels(input_shape, output_shape, down_blocks, up_blocks)
diff --git a/src/lightning_trainable/modules/unet/skip_connection.py b/src/lightning_trainable/modules/unet/skip_connection.py
deleted file mode 100644
index c6ac170..0000000
--- a/src/lightning_trainable/modules/unet/skip_connection.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import torch
-import torch.nn as nn
-
-
-class SkipConnection(nn.Module):
- def __init__(self, inner: nn.Module, mode: str = "add"):
- super().__init__()
- self.inner = inner
- self.mode = mode
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- match self.mode:
- case "add":
- return self.inner(x) + x
- case "concat":
- return torch.cat((self.inner(x), x), dim=1)
- case "none":
- return self.inner(x)
- case other:
- raise NotImplementedError(f"Unrecognized skip connection mode '{other}'.")
-
- def extra_repr(self) -> str:
- return f"mode={self.mode}"
diff --git a/src/lightning_trainable/modules/unet/temporary_flatten.py b/src/lightning_trainable/modules/unet/temporary_flatten.py
deleted file mode 100644
index 4b23244..0000000
--- a/src/lightning_trainable/modules/unet/temporary_flatten.py
+++ /dev/null
@@ -1,14 +0,0 @@
-import torch
-import torch.nn as nn
-
-
-class TemporaryFlatten(nn.Module):
- def __init__(self, inner: nn.Module, input_shape, output_shape):
- super().__init__()
- self.inner = inner
- self.input_shape = input_shape
- self.output_shape = output_shape
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- out = self.inner(x.flatten(1))
- return out.reshape(x.shape[0], *self.output_shape)
diff --git a/src/lightning_trainable/trainable/lr_schedulers/__init__.py b/src/lightning_trainable/trainable/lr_schedulers/__init__.py
new file mode 100644
index 0000000..9234965
--- /dev/null
+++ b/src/lightning_trainable/trainable/lr_schedulers/__init__.py
@@ -0,0 +1 @@
+from .configure import configure
diff --git a/src/lightning_trainable/trainable/lr_schedulers/configure.py b/src/lightning_trainable/trainable/lr_schedulers/configure.py
new file mode 100644
index 0000000..89c0307
--- /dev/null
+++ b/src/lightning_trainable/trainable/lr_schedulers/configure.py
@@ -0,0 +1,40 @@
+
+from torch.optim.lr_scheduler import LRScheduler
+
+from lightning_trainable.utils import get_scheduler
+
+from . import defaults
+
+
+def with_kwargs(model, optimizer, **kwargs) -> LRScheduler:
+ """
+ Get a learning rate scheduler with the given kwargs
+ Insert default values for missing kwargs
+ @param model: Trainable model
+ @param optimizer: The optimizer to use with the scheduler
+ @param kwargs: Keyword arguments for the scheduler
+ @return: The configured learning rate scheduler
+ """
+ name = kwargs.pop("name")
+ default_kwargs = defaults.get_defaults(name, model, optimizer)
+ kwargs = default_kwargs | kwargs
+ return get_scheduler(name)(optimizer, **kwargs)
+
+
+def configure(model, optimizer) -> LRScheduler | None:
+ """
+ Configure a learning rate scheduler from the model's hparams
+ @param model: Trainable model
+ @param optimizer: The optimizer to use with the scheduler
+ @return: The configured learning rate scheduler
+ """
+ match model.hparams.lr_scheduler:
+ case str() as name:
+ return with_kwargs(model, optimizer, name=name)
+ case dict() as kwargs:
+ return with_kwargs(model, optimizer, **kwargs.copy())
+ case None:
+ # do not use a learning rate scheduler
+ return None
+ case other:
+ raise NotImplementedError(f"Unrecognized Scheduler: '{other}'")
diff --git a/src/lightning_trainable/trainable/lr_schedulers/defaults.py b/src/lightning_trainable/trainable/lr_schedulers/defaults.py
new file mode 100644
index 0000000..18a672f
--- /dev/null
+++ b/src/lightning_trainable/trainable/lr_schedulers/defaults.py
@@ -0,0 +1,16 @@
+
+def get_defaults(scheduler_name, model, optimizer):
+ match scheduler_name:
+ case "OneCycleLR":
+ max_lr = optimizer.defaults["lr"]
+ total_steps = model.hparams.max_steps
+ if total_steps == -1:
+ total_steps = model.hparams.max_epochs * len(model.train_dataloader())
+ total_steps = int(total_steps / model.hparams.accumulate_batches)
+
+ return dict(
+ max_lr=max_lr,
+ total_steps=total_steps
+ )
+ case _:
+ return dict()
diff --git a/src/lightning_trainable/trainable/optimizers/__init__.py b/src/lightning_trainable/trainable/optimizers/__init__.py
new file mode 100644
index 0000000..9234965
--- /dev/null
+++ b/src/lightning_trainable/trainable/optimizers/__init__.py
@@ -0,0 +1 @@
+from .configure import configure
diff --git a/src/lightning_trainable/trainable/optimizers/configure.py b/src/lightning_trainable/trainable/optimizers/configure.py
new file mode 100644
index 0000000..3828d13
--- /dev/null
+++ b/src/lightning_trainable/trainable/optimizers/configure.py
@@ -0,0 +1,38 @@
+
+from torch.optim import Optimizer
+
+from lightning_trainable.utils import get_optimizer
+
+from . import defaults
+
+
+def with_kwargs(model, **kwargs) -> Optimizer:
+ """
+ Get an optimizer with the given kwargs
+ Insert default values for missing kwargs
+ @param model: Trainable model
+ @param kwargs: Keyword arguments for the optimizer
+ @return: The configured optimizer
+ """
+ name = kwargs.pop("name")
+ default_kwargs = defaults.get_defaults(name, model)
+ kwargs = default_kwargs | kwargs
+ return get_optimizer(name)(**kwargs)
+
+
+def configure(model) -> Optimizer | None:
+ """
+ Configure an optimizer from the model's hparams
+ @param model: Trainable model
+ @return: The configured optimizer
+ """
+ match model.hparams.optimizer:
+ case str() as name:
+ return with_kwargs(model, name=name)
+ case dict() as kwargs:
+ return with_kwargs(model, **kwargs.copy())
+ case None:
+ # do not use an optimizer
+ return None
+ case other:
+ raise NotImplementedError(f"Unrecognized Optimizer: '{other}'")
diff --git a/src/lightning_trainable/trainable/optimizers/defaults.py b/src/lightning_trainable/trainable/optimizers/defaults.py
new file mode 100644
index 0000000..5b9b23c
--- /dev/null
+++ b/src/lightning_trainable/trainable/optimizers/defaults.py
@@ -0,0 +1,7 @@
+
+def get_defaults(optimizer_name, model):
+ match optimizer_name:
+ case _:
+ return dict(
+ params=model.parameters(),
+ )
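
Together, the new optimizers and lr_schedulers sub-packages resolve the optimizer and lr_scheduler hparams fields. A hedged, partial hparams sketch (other required TrainableHParams fields omitted; names must now match the torch class names, see the migration below):

```python
hparams = dict(
    batch_size=32,
    max_epochs=10,
    # names are case-sensitive and match the torch class names
    optimizer=dict(name="Adam", lr=1e-3),
    # for OneCycleLR, max_lr and total_steps are filled in by lr_schedulers.defaults
    lr_scheduler="OneCycleLR",
)
```
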
diff --git a/src/lightning_trainable/trainable/trainable.py b/src/lightning_trainable/trainable/trainable.py
index 210a7cb..9684043 100644
--- a/src/lightning_trainable/trainable/trainable.py
+++ b/src/lightning_trainable/trainable/trainable.py
@@ -1,18 +1,19 @@
-import os
-import pathlib
-from copy import deepcopy
-
import lightning
+import os
import torch
-from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint, ProgressBar, EarlyStopping
+
+from copy import deepcopy
+from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import Logger, TensorBoardLogger
from torch.utils.data import DataLoader, Dataset, IterableDataset
from tqdm import tqdm
from lightning_trainable import utils
-from lightning_trainable.callbacks import EpochProgressBar
from .trainable_hparams import TrainableHParams
+from . import lr_schedulers
+from . import optimizers
+
class SkipBatch(Exception):
pass
@@ -33,15 +34,16 @@ def __init__(
if not isinstance(hparams, self.hparams_type):
hparams = self.hparams_type(**hparams)
self.save_hyperparameters(hparams)
- # workaround for https://github.com/Lightning-AI/lightning/issues/17889
- self._hparams_name = "hparams"
self.train_data = train_data
self.val_data = val_data
self.test_data = test_data
def __init_subclass__(cls, **kwargs):
- cls.hparams_type = cls.__annotations__.get("hparams", TrainableHParams)
+ hparams_type = cls.__annotations__.get("hparams")
+ if hparams_type is not None:
+ # only overwrite hparams_type if it is defined by the child class
+ cls.hparams_type = hparams_type
def compute_metrics(self, batch, batch_idx) -> dict:
"""
@@ -89,97 +91,18 @@ def test_step(self, batch, batch_idx):
self.log(f"test/{key}", value,
prog_bar=key == self.hparams.loss)
- def configure_lr_schedulers(self, optimizer):
- """
- Configure the LR Scheduler as defined in HParams.
- By default, we only use a single LR Scheduler, attached to a single optimizer.
- You can use a ChainedScheduler if you need multiple LR Schedulers throughout training,
- or override this method if you need different schedulers for different parameters.
-
- @param optimizer: The optimizer to attach the scheduler to.
- @return: The LR Scheduler object.
- """
- match self.hparams.lr_scheduler:
- case str() as name:
- match name.lower():
- case "onecyclelr":
- kwargs = dict(
- max_lr=optimizer.defaults["lr"],
- epochs=self.hparams.max_epochs,
- steps_per_epoch=len(self.train_dataloader())
- )
- interval = "step"
- case _:
- kwargs = dict()
- interval = "step"
- scheduler = utils.get_scheduler(name)(optimizer, **kwargs)
- return dict(
- scheduler=scheduler,
- interval=interval,
- )
- case dict() as kwargs:
- # Copy the dict so we don't modify the original
- kwargs = deepcopy(kwargs)
-
- name = kwargs.pop("name")
- interval = "step"
- if "interval" in kwargs:
- interval = kwargs.pop("interval")
- config_kwargs = kwargs.pop("config", dict())
- scheduler = utils.get_scheduler(name)(optimizer, **kwargs)
- return dict(
- scheduler=scheduler,
- interval=interval,
- **config_kwargs
- )
- case type(torch.optim.lr_scheduler.LRScheduler) as Scheduler:
- kwargs = dict()
- interval = "step"
- scheduler = Scheduler(optimizer, **kwargs)
- return dict(
- scheduler=scheduler,
- interval=interval,
- )
- case (torch.optim.lr_scheduler.LRScheduler() | torch.optim.lr_scheduler.ReduceLROnPlateau()) as scheduler:
- return dict(
- scheduler=scheduler,
- interval="step",
- )
- case None:
- # do not use a scheduler
- return None
- case other:
- raise NotImplementedError(f"Unrecognized Scheduler: {other}")
-
def configure_optimizers(self):
"""
- Configure Optimizer and LR Scheduler objects as defined in HParams.
- By default, we only use a single optimizer and an optional LR Scheduler.
- If you need multiple optimizers, override this method.
+ Configure the optimizer and learning rate scheduler for this model, based on the HParams.
+ This method is called automatically by the Lightning Trainer in module fitting.
- @return: A dictionary containing the optimizer and lr_scheduler.
- """
- kwargs = dict()
-
- match self.hparams.optimizer:
- case str() as name:
- optimizer = utils.get_optimizer(name)(self.parameters(), **kwargs)
- case dict() as kwargs:
- # Copy the dict so we don't modify the original
- kwargs = deepcopy(kwargs)
-
- name = kwargs.pop("name")
- optimizer = utils.get_optimizer(name)(self.parameters(), **kwargs)
- case type(torch.optim.Optimizer) as Optimizer:
- optimizer = Optimizer(self.parameters(), **kwargs)
- case torch.optim.Optimizer() as optimizer:
- pass
- case None:
- return None
- case other:
- raise NotImplementedError(f"Unrecognized Optimizer: {other}")
+ By default, we use one optimizer and zero or one learning rate scheduler.
+ If you want to use multiple optimizers or learning rate schedulers, you must override this method.
- lr_scheduler = self.configure_lr_schedulers(optimizer)
+ @return: A dictionary containing the optimizer and learning rate scheduler, if any.
+ """
+ optimizer = optimizers.configure(self)
+ lr_scheduler = lr_schedulers.configure(self, optimizer)
if lr_scheduler is None:
return optimizer
@@ -207,7 +130,6 @@ def configure_callbacks(self) -> list:
callbacks = [
ModelCheckpoint(**checkpoint_kwargs),
LearningRateMonitor(),
- EpochProgressBar(),
]
if self.hparams.early_stopping is not None:
if self.hparams.early_stopping["monitor"] == "auto":
@@ -263,28 +185,22 @@ def test_dataloader(self) -> DataLoader | list[DataLoader]:
num_workers=self.hparams.num_workers,
)
- def configure_logger(self, save_dir=os.getcwd(), **kwargs) -> Logger:
+ def configure_logger(self, logger_name: str = "TensorBoardLogger", **logger_kwargs) -> Logger:
"""
Instantiate the Logger used by the Trainer in module fitting.
By default, we use a TensorBoardLogger, but you can use any other logger of your choice.
- @param save_dir: The root directory in which all your experiments with
- different names and versions will be stored.
- @param kwargs: Keyword-Arguments to the Logger. Set `logger_name` to use a different logger than TensorBoardLogger.
+ @param logger_name: The name of the logger to use. Defaults to TensorBoardLogger.
+        @param logger_kwargs: Additional keyword arguments passed to the Logger.
@return: The Logger object.
"""
- logger_kwargs = deepcopy(kwargs)
- logger_kwargs.update(dict(
- save_dir=save_dir,
- ))
- logger_name = logger_kwargs.pop("logger_name", "TensorBoardLogger")
+ logger_kwargs = logger_kwargs or {}
+ logger_kwargs.setdefault("save_dir", os.getcwd())
logger_class = utils.get_logger(logger_name)
if issubclass(logger_class, TensorBoardLogger):
- logger_kwargs["default_hp_metric"] = False
+ logger_kwargs.setdefault("default_hp_metric", False)
- return logger_class(
- **logger_kwargs
- )
+ return logger_class(**logger_kwargs)
def configure_trainer(self, logger_kwargs: dict = None, trainer_kwargs: dict = None) -> lightning.Trainer:
"""
@@ -296,27 +212,20 @@ def configure_trainer(self, logger_kwargs: dict = None, trainer_kwargs: dict = None) -> lightning.Trainer:
See also :func:`~trainable.Trainable.configure_trainer`.
@return: The Lightning Trainer object.
"""
- if logger_kwargs is None:
- logger_kwargs = dict()
- if trainer_kwargs is None:
- trainer_kwargs = dict()
-
- if "enable_progress_bar" not in trainer_kwargs:
- if any(isinstance(callback, ProgressBar) for callback in self.configure_callbacks()):
- trainer_kwargs["enable_progress_bar"] = False
-
- return lightning.Trainer(
- accelerator=self.hparams.accelerator.lower(),
- logger=self.configure_logger(**logger_kwargs),
- devices=self.hparams.devices,
- max_epochs=self.hparams.max_epochs,
- max_steps=self.hparams.max_steps,
- gradient_clip_val=self.hparams.gradient_clip,
- accumulate_grad_batches=self.hparams.accumulate_batches,
- profiler=self.hparams.profiler,
- benchmark=True,
- **trainer_kwargs,
- )
+ logger_kwargs = logger_kwargs or {}
+ trainer_kwargs = trainer_kwargs or {}
+
+ trainer_kwargs.setdefault("accelerator", self.hparams.accelerator.lower())
+ trainer_kwargs.setdefault("accumulate_grad_batches", self.hparams.accumulate_batches)
+ trainer_kwargs.setdefault("benchmark", True)
+ trainer_kwargs.setdefault("devices", self.hparams.devices)
+ trainer_kwargs.setdefault("gradient_clip_val", self.hparams.gradient_clip)
+ trainer_kwargs.setdefault("logger", self.configure_logger(**logger_kwargs))
+ trainer_kwargs.setdefault("max_epochs", self.hparams.max_epochs)
+ trainer_kwargs.setdefault("max_steps", self.hparams.max_steps)
+ trainer_kwargs.setdefault("profiler", self.hparams.profiler)
+
+ return lightning.Trainer(**trainer_kwargs)
def on_before_optimizer_step(self, optimizer):
# who doesn't love breaking changes in underlying libraries
@@ -329,6 +238,14 @@ def on_before_optimizer_step(self, optimizer):
case other:
raise NotImplementedError(f"Unrecognized grad norm: {other}")
+ def on_train_start(self) -> None:
+        # compute metrics on a sample training batch for the hparams entry
+ test_batch = next(iter(self.trainer.train_dataloader))
+ metrics = self.compute_metrics(test_batch, 0)
+
+ # add hparams to tensorboard
+ self.logger.log_hyperparams(self.hparams, metrics)
+
@torch.enable_grad()
def fit(self, logger_kwargs: dict = None, trainer_kwargs: dict = None, fit_kwargs: dict = None) -> dict:
"""
@@ -341,26 +258,18 @@ def fit(self, logger_kwargs: dict = None, trainer_kwargs: dict = None, fit_kwarg
@param fit_kwargs: Keyword-Arguments to the Trainer's fit method.
@return: Validation Metrics as defined in :func:`~trainable.Trainable.compute_metrics`.
"""
- if logger_kwargs is None:
- logger_kwargs = dict()
- if trainer_kwargs is None:
- trainer_kwargs = dict()
- if fit_kwargs is None:
- fit_kwargs = dict()
+ logger_kwargs = logger_kwargs or {}
+ trainer_kwargs = trainer_kwargs or {}
+ fit_kwargs = fit_kwargs or {}
trainer = self.configure_trainer(logger_kwargs, trainer_kwargs)
- metrics_list = trainer.validate(self)
- if metrics_list is not None and len(metrics_list) > 0:
- metrics = metrics_list[0]
- else:
- metrics = {}
- trainer.logger.log_hyperparams(self.hparams, metrics)
+
trainer.fit(self, **fit_kwargs)
return {
key: value.item()
for key, value in trainer.callback_metrics.items()
- if any(key.startswith(key) for key in ["training/", "validation/"])
+ if any(key.startswith(k) for k in ["training/", "validation/"])
}
@torch.enable_grad()
@@ -378,7 +287,14 @@ def fit_fast(self, device="cuda"):
self.train()
self.to(device)
- optimizer = self.configure_optimizers()["optimizer"]
+ maybe_optimizer = self.configure_optimizers()
+ if isinstance(maybe_optimizer, dict):
+ optimizer = maybe_optimizer["optimizer"]
+ elif isinstance(maybe_optimizer, torch.optim.Optimizer):
+ optimizer = maybe_optimizer
+ else:
+ raise RuntimeError("Invalid optimizer")
+
dataloader = self.train_dataloader()
loss = None
@@ -387,18 +303,12 @@ def fit_fast(self, device="cuda"):
batch = tuple(t.to(device) for t in batch if torch.is_tensor(t))
optimizer.zero_grad()
- loss = self.training_step(batch, 0)
+ loss = self.training_step(batch, batch_idx)
loss.backward()
optimizer.step()
return loss
- @classmethod
- def load_checkpoint(cls, root: str | pathlib.Path = "lightning_logs", version: int | str = "last",
- epoch: int | str = "last", step: int | str = "last", **kwargs):
- checkpoint = utils.find_checkpoint(root, version, epoch, step)
- return cls.load_from_checkpoint(checkpoint, **kwargs)
-
def auto_pin_memory(pin_memory: bool | None, accelerator: str):
if pin_memory is None:
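
Since configure_trainer now only sets defaults, user-supplied trainer_kwargs take precedence. A hedged sketch of overriding them through fit, where model stands for any Trainable subclass instance (a placeholder, not defined in this patch):

```python
metrics = model.fit(
    logger_kwargs=dict(save_dir="runs", name="my_experiment"),
    trainer_kwargs=dict(max_epochs=5, enable_progress_bar=False),
)

# fit returns the training/validation metrics recorded by the trainer
print(metrics)
```
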
diff --git a/src/lightning_trainable/trainable/trainable_hparams.py b/src/lightning_trainable/trainable/trainable_hparams.py
index 3b58994..d27a455 100644
--- a/src/lightning_trainable/trainable/trainable_hparams.py
+++ b/src/lightning_trainable/trainable/trainable_hparams.py
@@ -1,6 +1,7 @@
from lightning_trainable.hparams import HParams
from lightning.pytorch.profilers import Profiler
+import warnings
class TrainableHParams(HParams):
@@ -11,7 +12,7 @@ class TrainableHParams(HParams):
devices: int = 1
max_epochs: int | None
max_steps: int = -1
- optimizer: str | dict | None = "adam"
+ optimizer: str | dict | None = "Adam"
lr_scheduler: str | dict | None = None
batch_size: int
accumulate_batches: int = 1
@@ -31,7 +32,37 @@ class TrainableHParams(HParams):
@classmethod
def _migrate_hparams(cls, hparams):
if "accumulate_batches" in hparams and hparams["accumulate_batches"] is None:
+ deprecate("accumulate_batches changed default value: None -> 1")
hparams["accumulate_batches"] = 1
+
+ if "optimizer" in hparams:
+ match hparams["optimizer"]:
+ case str() as name:
+ if name == name.lower():
+ deprecate("optimizer name is now case-sensitive.")
+ if name == "adam":
+ hparams["optimizer"] = "Adam"
+ case dict() as kwargs:
+ name = kwargs["name"]
+ if name == name.lower():
+ deprecate("optimizer name is now case-sensitive.")
+ if name == "adam":
+ hparams["optimizer"]["name"] = "Adam"
+
+ if "lr_scheduler" in hparams:
+ match hparams["lr_scheduler"]:
+ case str() as name:
+ if name == name.lower():
+ deprecate("lr_scheduler name is now case-sensitive.")
+ if name == "onecyclelr":
+ hparams["lr_scheduler"] = "OneCycleLR"
+ case dict() as kwargs:
+ name = kwargs["name"]
+ if name == name.lower():
+ deprecate("lr_scheduler name is now case-sensitive.")
+ if name == "onecyclelr":
+ hparams["lr_scheduler"]["name"] = "OneCycleLR"
+
if "early_stopping" in hparams and isinstance(hparams["early_stopping"], int):
hparams["early_stopping"] = dict(monitor="auto", patience=hparams["early_stopping"])
return hparams
diff --git a/src/lightning_trainable/utils/__init__.py b/src/lightning_trainable/utils/__init__.py
index 426dc0f..def6be2 100644
--- a/src/lightning_trainable/utils/__init__.py
+++ b/src/lightning_trainable/utils/__init__.py
@@ -1,3 +1,4 @@
+from .deprecate import deprecate
from .io import find_checkpoint
from .iteration import flatten, zip
from .modules import (
diff --git a/src/lightning_trainable/utils/deprecate.py b/src/lightning_trainable/utils/deprecate.py
new file mode 100644
index 0000000..ff592d4
--- /dev/null
+++ b/src/lightning_trainable/utils/deprecate.py
@@ -0,0 +1,33 @@
+
+import warnings
+from functools import wraps
+import inspect
+
+
+def deprecate(message: str, version: str = None):
+
+ def wrapper(obj):
+ nonlocal message
+
+ if inspect.isclass(obj):
+ name = "Class"
+ elif inspect.isfunction(obj):
+ name = "Function"
+ elif inspect.ismethod(obj):
+ name = "Method"
+ else:
+ name = "Object"
+
+ if version is not None:
+ message = f"{name} '{obj.__name__}' is deprecated since version {version}: {message}"
+ else:
+ message = f"{name} '{obj.__name__}' is deprecated: {message}"
+
+ @wraps(obj)
+ def wrapped(*args, **kwargs):
+ warnings.warn(message, DeprecationWarning, stacklevel=2)
+ return obj(*args, **kwargs)
+
+ return wrapped
+
+ return wrapper
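
For illustration, a small usage sketch of the new deprecate decorator (the decorated function is hypothetical; the warning text is assembled as shown above):

```python
import warnings

from lightning_trainable.utils import deprecate


@deprecate("Use new_helper instead.", version="0.3")
def old_helper():
    return None


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old_helper()

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```
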
diff --git a/src/lightning_trainable/utils/io.py b/src/lightning_trainable/utils/io.py
index 73a0931..b44a4c2 100644
--- a/src/lightning_trainable/utils/io.py
+++ b/src/lightning_trainable/utils/io.py
@@ -9,6 +9,10 @@ def find_version(root: str | Path = "lightning_logs", version: int = "last") ->
# Determine latest version number if "last" is passed as version number
if version == "last":
version_folders = [f for f in root.iterdir() if f.is_dir() and re.match(r"^version_(\d+)$", f.name)]
+
+ if not version_folders:
+ raise FileNotFoundError(f"Found no folders in '{root}' matching the name pattern 'version_'")
+
version_numbers = [int(re.match(r"^version_(\d+)$", f.name).group(1)) for f in version_folders]
version = max(version_numbers)
@@ -25,7 +29,6 @@ def find_epoch_step(root: str | Path, epoch: int = "last", step: int = "last") -
@param step: Step number or "last"
@return: epoch and step numbers
"""
-
root = Path(root)
# get checkpoint filenames
@@ -36,6 +39,10 @@ def find_epoch_step(root: str | Path, epoch: int = "last", step: int = "last") -
# remove invalid files
checkpoints = [cp for cp in checkpoints if pattern.match(cp)]
+ if not checkpoints:
+ raise FileNotFoundError(f"Found no checkpoints in '{root}' matching "
+ f"the name pattern 'epoch=-step=.ckpt'")
+
# get epochs and steps as list
matches = [pattern.match(cp) for cp in checkpoints]
epochs, steps = zip(*[(int(match.group(1)), int(match.group(2))) for match in matches])
@@ -43,7 +50,8 @@ def find_epoch_step(root: str | Path, epoch: int = "last", step: int = "last") -
if epoch == "last":
epoch = max(epochs)
elif epoch not in epochs:
- raise FileNotFoundError(f"No checkpoint in '{root}' for epoch '{epoch}'.")
+ closest_epoch = min(epochs, key=lambda e: abs(e - epoch))
+ raise FileNotFoundError(f"No checkpoint in '{root}' for epoch '{epoch}'. Closest is '{closest_epoch}'.")
# keep only steps for this epoch
steps = [s for i, s in enumerate(steps) if epochs[i] == epoch]
@@ -51,7 +59,9 @@ def find_epoch_step(root: str | Path, epoch: int = "last", step: int = "last") -
if step == "last":
step = max(steps)
elif step not in steps:
- raise FileNotFoundError(f"No checkpoint in '{root}' for epoch '{epoch}', step '{step}'")
+ closest_step = min(steps, key=lambda s: abs(s - step))
+ raise FileNotFoundError(f"No checkpoint in '{root}' for epoch '{epoch}', step '{step}'. "
+ f"Closest step for this epoch is '{closest_step}'.")
return epoch, step
@@ -60,7 +70,7 @@ def find_checkpoint(root: str | Path = "lightning_logs", version: int = "last",
"""
Helper method to find a lightning checkpoint based on version, epoch and step numbers.
- @param root: logs root directory. Usually "lightning_logs/"
+    @param root: logs root directory. Usually "lightning_logs/", the default name used by the TensorBoardLogger.
@param version: version number or "last"
@param epoch: epoch number or "last"
@param step: step number or "last"
@@ -69,7 +79,7 @@ def find_checkpoint(root: str | Path = "lightning_logs", version: int = "last",
root = Path(root)
if not root.is_dir():
- raise ValueError(f"Root directory '{root}' does not exist")
+ raise ValueError(f"Checkpoint root directory '{root}' does not exist")
# get existing version number or error
version = find_version(root, version)
@@ -87,7 +97,4 @@ def find_checkpoint(root: str | Path = "lightning_logs", version: int = "last",
checkpoint = checkpoint_folder / f"epoch={epoch}-step={step}.ckpt"
- if not checkpoint.is_file():
- raise FileNotFoundError(f"{version=}, {epoch=}, {step=}")
-
return str(checkpoint)
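
As a usage sketch for the lookup helpers changed above (a sketch only, assuming a completed run already exists under "lightning_logs/"; the epoch number below is a deliberately impossible placeholder):

    from pathlib import Path

    from lightning_trainable.utils import find_checkpoint

    # defaults resolve to the newest run: version="last", epoch="last", step="last"
    checkpoint = find_checkpoint(root="lightning_logs")
    assert Path(checkpoint).name.startswith("epoch=")

    # asking for an epoch that was never saved raises FileNotFoundError,
    # and the message now names the closest available epoch
    try:
        find_checkpoint(root="lightning_logs", epoch=10_000)
    except FileNotFoundError as error:
        print(error)
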
diff --git a/tests/module_tests/__init__.py b/tests/module_tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/module_tests/test_trainable.py b/tests/module_tests/test_trainable.py
new file mode 100644
index 0000000..31283cf
--- /dev/null
+++ b/tests/module_tests/test_trainable.py
@@ -0,0 +1,145 @@
+
+import pytest
+
+import torch
+import torch.nn as nn
+from torch.utils.data import TensorDataset
+
+from pathlib import Path
+
+from lightning_trainable.hparams import HParams
+from lightning_trainable.trainable import Trainable, TrainableHParams
+from lightning_trainable.utils import find_checkpoint
+
+
+@pytest.fixture
+def dummy_dataset():
+ return TensorDataset(torch.randn(128, 2))
+
+
+@pytest.fixture
+def dummy_network():
+ return nn.Linear(2, 2)
+
+
+@pytest.fixture
+def dummy_hparams_cls():
+ class DummyHParams(TrainableHParams):
+ max_epochs: int = 10
+ batch_size: int = 4
+ accelerator: str = "cpu"
+
+ return DummyHParams
+
+
+@pytest.fixture
+def dummy_hparams(dummy_hparams_cls):
+ return dummy_hparams_cls()
+
+
+@pytest.fixture
+def dummy_model_cls(dummy_network, dummy_dataset, dummy_hparams_cls):
+ class DummyModel(Trainable):
+ hparams: dummy_hparams_cls
+
+ def __init__(self, hparams):
+ super().__init__(hparams, train_data=dummy_dataset, val_data=dummy_dataset, test_data=dummy_dataset)
+ self.network = dummy_network
+
+ def compute_metrics(self, batch, batch_idx) -> dict:
+ return dict(
+ loss=torch.tensor(0.0, requires_grad=True)
+ )
+
+ return DummyModel
+
+
+@pytest.fixture
+def dummy_model(dummy_model_cls, dummy_hparams):
+ return dummy_model_cls(dummy_hparams)
+
+
+def test_fit(dummy_model):
+ train_metrics = dummy_model.fit()
+
+ assert isinstance(train_metrics, dict)
+ assert "training/loss" in train_metrics
+ assert "validation/loss" in train_metrics
+
+
+def test_fit_fast(dummy_model):
+ loss = dummy_model.fit_fast(device="cpu")
+
+ assert torch.isclose(loss, torch.tensor(0.0))
+
+
+def test_hparams_copy(dummy_hparams, dummy_hparams_cls):
+ assert isinstance(dummy_hparams, HParams)
+ assert isinstance(dummy_hparams, TrainableHParams)
+ assert isinstance(dummy_hparams, dummy_hparams_cls)
+
+ hparams_copy = dummy_hparams.copy()
+
+ assert isinstance(hparams_copy, HParams)
+ assert isinstance(hparams_copy, TrainableHParams)
+ assert isinstance(hparams_copy, dummy_hparams_cls)
+
+
+def test_hparams_invariant(dummy_model_cls, dummy_hparams):
+ """ Ensure HParams are left unchanged after instantiation and training """
+ hparams = dummy_hparams.copy()
+
+ dummy_model1 = dummy_model_cls(hparams)
+
+ assert hparams == dummy_hparams
+
+ dummy_model1.fit()
+
+ assert hparams == dummy_hparams
+
+
+def test_checkpoint(dummy_model):
+ # TODO: temp directory
+
+ dummy_model.fit()
+
+ checkpoint = find_checkpoint()
+
+ assert Path(checkpoint).is_file()
+
+
+def test_nested_checkpoint(dummy_model_cls, dummy_hparams_cls):
+
+ class MyHParams(dummy_hparams_cls):
+ pass
+
+ class MyModel(dummy_model_cls):
+ hparams: MyHParams
+
+ def __init__(self, hparams):
+ super().__init__(hparams)
+
+ hparams = MyHParams()
+ model = MyModel(hparams)
+
+ assert model._hparams_name == "hparams"
+
+
+def test_continue_training(dummy_model):
+ print("Starting Training.")
+ dummy_model.fit()
+
+ print("Finished Training. Loading Checkpoint.")
+ checkpoint = find_checkpoint()
+
+ trained_model = dummy_model.__class__.load_from_checkpoint(checkpoint)
+
+ print("Continuing Training.")
+ trained_model.fit(
+ trainer_kwargs=dict(max_epochs=2 * dummy_model.hparams.max_epochs),
+ fit_kwargs=dict(ckpt_path=checkpoint)
+ )
+
+ print("Finished Continued Training.")
+
+ # TODO: add check that the model was actually trained for 2x epochs
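
The remaining TODO in test_continue_training could be closed along these lines. This is a sketch under assumptions, not verified against this trainer's checkpoint cadence: it assumes Lightning's default checkpoint naming, where the newest checkpoint of the continued run is named after the final, zero-based epoch.

    import re

    # append inside test_continue_training, after the second fit() call;
    # find_checkpoint and Path are already imported at the top of the test module
    checkpoint = find_checkpoint()  # version="last" should point at the continued run
    match = re.match(r"epoch=(\d+)-step=(\d+)\.ckpt", Path(checkpoint).name)
    last_epoch = int(match.group(1))
    # 2 * max_epochs total epochs, zero-based, so the last one is 2 * max_epochs - 1
    assert last_epoch == 2 * dummy_model.hparams.max_epochs - 1
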
diff --git a/tests/test_trainable.py b/tests/test_trainable.py
deleted file mode 100644
index b5e8ab3..0000000
--- a/tests/test_trainable.py
+++ /dev/null
@@ -1,70 +0,0 @@
-import torch
-from torch.utils.data import TensorDataset, Dataset
-
-from lightning_trainable import Trainable, TrainableHParams
-
-
-def test_instantiate():
- hparams = TrainableHParams(max_epochs=10, batch_size=32)
- Trainable(hparams)
-
-
-def test_simple_model():
- class SimpleTrainable(Trainable):
- def __init__(self, hparams: TrainableHParams | dict,
- train_data: Dataset = None,
- val_data: Dataset = None,
- test_data: Dataset = None
- ):
- super().__init__(hparams, train_data, val_data, test_data)
- self.param = torch.nn.Parameter(torch.randn(8, 1))
-
- def compute_metrics(self, batch, batch_idx) -> dict:
- return {
- "loss": ((batch[0] @ self.param) ** 2).mean()
- }
-
- train_data = TensorDataset(torch.randn(128, 8))
-
- hparams = TrainableHParams(
- accelerator="cpu",
- max_epochs=10,
- batch_size=32,
- lr_scheduler="onecyclelr"
- )
- model = SimpleTrainable(hparams, train_data=train_data)
- model.fit()
-
-
-def test_double_train():
- class SimpleTrainable(Trainable):
- def __init__(self, hparams: TrainableHParams | dict,
- train_data: Dataset = None,
- val_data: Dataset = None,
- test_data: Dataset = None
- ):
- super().__init__(hparams, train_data, val_data, test_data)
- self.param = torch.nn.Parameter(torch.randn(8, 1))
-
- def compute_metrics(self, batch, batch_idx) -> dict:
- return {
- "loss": ((batch[0] @ self.param) ** 2).mean()
- }
-
- hparams = TrainableHParams(
- accelerator="cpu",
- max_epochs=1,
- batch_size=8,
- optimizer=dict(
- name="adam",
- lr=1e-3,
- )
- )
-
- train_data = TensorDataset(torch.randn(128, 8))
-
- t1 = SimpleTrainable(hparams, train_data=train_data, val_data=train_data)
- t1.fit()
-
- t2 = SimpleTrainable(hparams, train_data=train_data, val_data=train_data)
- t2.fit()
diff --git a/tests/test_unet.py b/tests/test_unet.py
deleted file mode 100644
index 069b707..0000000
--- a/tests/test_unet.py
+++ /dev/null
@@ -1,60 +0,0 @@
-
-import pytest
-
-import torch
-
-from lightning_trainable.modules import UNet, UNetHParams
-
-
-@pytest.mark.parametrize("skip_mode", ["add", "concat", "none"])
-@pytest.mark.parametrize("width", [32, 48, 64])
-@pytest.mark.parametrize("height", [32, 48, 64])
-@pytest.mark.parametrize("channels", [1, 3, 5])
-def test_basic(skip_mode, channels, width, height):
- hparams = UNetHParams(
- input_shape=(channels, height, width),
- output_shape=(channels, height, width),
- down_blocks=[
- dict(channels=[16, 32], kernel_sizes=[3, 3, 3]),
- dict(channels=[32, 64], kernel_sizes=[3, 3, 3]),
- ],
- up_blocks=[
- dict(channels=[64, 32], kernel_sizes=[3, 3, 3]),
- dict(channels=[32, 16], kernel_sizes=[3, 3, 3]),
- ],
- bottom_widths=[32, 32],
- skip_mode=skip_mode,
- activation="relu",
- )
- unet = UNet(hparams)
-
- x = torch.randn(1, *hparams.input_shape)
- y = unet(x)
- assert y.shape == (1, *hparams.output_shape)
-
-
-@pytest.mark.parametrize("skip_mode", ["concat", "none"])
-@pytest.mark.parametrize("width", [32, 48, 64])
-@pytest.mark.parametrize("height", [32, 48, 64])
-@pytest.mark.parametrize("channels", [1, 3, 5])
-def test_inconsistent_channels(skip_mode, channels, width, height):
- hparams = UNetHParams(
- input_shape=(channels, height, width),
- output_shape=(channels, height, width),
- down_blocks=[
- dict(channels=[16, 24], kernel_sizes=[3, 3, 3]),
- dict(channels=[48, 17], kernel_sizes=[3, 3, 3]),
- ],
- up_blocks=[
- dict(channels=[57, 31], kernel_sizes=[3, 3, 3]),
- dict(channels=[12, 87], kernel_sizes=[3, 3, 3]),
- ],
- bottom_widths=[32, 32],
- skip_mode=skip_mode,
- activation="relu",
- )
- unet = UNet(hparams)
-
- x = torch.randn(1, *hparams.input_shape)
- y = unet(x)
- assert y.shape == (1, *hparams.output_shape)