Merge pull request #77 from aai-institute/feature/trainer-criterion-loss
Feature/trainer criterion loss
samuelburbulla authored Mar 1, 2024
2 parents ba1af69 + 1132f67 commit dc1acc3
Showing 16 changed files with 159 additions and 223 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,8 @@
 - Add `Sampler`, `BoxSampler`, `UniformBoxSampler`, and `RegularGridSampler` classes.
 - Moved `DataLoader` into the `fit` method of the `Trainer`.
   Therefore, `Trainer.fit` expects an `OperatorDataset` now.
+- A `Criterion` now enables stopping the training loop.
+- The `plotting` module has been removed.

 ## 0.0.0 (2024-02-22)
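The changelog entries above translate into the following call pattern. This is a hedged sketch, not code from the commit: `operator` and `dataset` are placeholders for an existing `Operator` and `OperatorDataset`, and the `Trainer` import path is assumed from the file layout in this diff.

```python
import torch

from continuity.trainer import Trainer  # import path assumed

# `operator` (an Operator) and `dataset` (an OperatorDataset) are
# placeholders assumed to exist already; batching now happens inside fit.
optimizer = torch.optim.Adam(operator.parameters(), lr=1e-3)
trainer = Trainer(operator, optimizer)

# Before this change: trainer.fit(data_loader, epochs=100)
# After: pass the OperatorDataset itself.
trainer.fit(dataset, batch_size=32, epochs=100)
```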
26 changes: 15 additions & 11 deletions examples/selfsupervised.ipynb

Large diffs are not rendered by default.

4 changes: 0 additions & 4 deletions src/continuity/__init__.py
@@ -25,10 +25,6 @@
     content: Loss functions for physics-informed training.
     url: pde/index.md
-  - title: Plotting
-    content: Plotting utilities.
-    url: plotting/index.md
   - title: Trainer
     content: Default training loop for operator models.
     url: trainer/index.md
9 changes: 0 additions & 9 deletions src/continuity/plotting/__init__.py

This file was deleted.

89 changes: 0 additions & 89 deletions src/continuity/plotting/plot.py

This file was deleted.

44 changes: 21 additions & 23 deletions src/continuity/trainer/callbacks.py
@@ -5,8 +5,9 @@
 """

 from abc import ABC, abstractmethod
-from typing import Optional, List, Dict
+from typing import Optional, List
 import matplotlib.pyplot as plt
+from .logs import Logs


 class Callback(ABC):
@@ -15,13 +15,12 @@ class Callback(ABC):
     """

     @abstractmethod
-    def __call__(self, epoch, logs: Dict[str, float]):
+    def __call__(self, logs: Logs):
         """Callback function.

         Called at the end of each epoch.

         Args:
-            epoch: Current epoch.
-            logs: Dictionary of logs.
+            logs: Training logs.
         """
         raise NotImplementedError
@@ -42,20 +42,17 @@ class PrintTrainingLoss(Callback):
     def __init__(self):
         super().__init__()

-    def __call__(self, epoch: int, logs: Dict[str, float]):
+    def __call__(self, logs: Logs):
         """Callback function.

         Called at the end of each epoch.

         Args:
-            epoch: Current epoch.
-            logs: Dictionary of logs.
+            logs: Training logs.
         """
-        loss_train = logs["loss/train"]
-        seconds_per_epoch = logs["seconds_per_epoch"]
-
         print(
-            f"\rEpoch {epoch}: loss/train = {loss_train:.4e} "
-            f"({seconds_per_epoch:.3g} s/epoch)",
+            f"\rEpoch {logs.epoch}: "
+            f"loss/train = {logs.loss_train:.4e} "
+            f"({logs.seconds_per_epoch:.3f} s/epoch)",
             end="",
         )
@@ -72,28 +69,30 @@ class LearningCurve(Callback):
     Callback to plot learning curve.

     Args:
-        keys: List of keys to plot. Default is ["loss/train"].
+        keys: List of keys to plot. Default is ["loss_train"].
     """

     def __init__(self, keys: Optional[List[str]] = None):
         if keys is None:
-            keys = ["loss/train"]
+            keys = ["loss_train"]

         self.keys = keys
         self.on_train_begin()
         super().__init__()

-    def __call__(self, epoch: int, logs: Dict[str, float]):
+    def __call__(self, logs: Logs):
         """Callback function.

         Called at the end of each epoch.

         Args:
-            epoch: Current epoch.
-            logs: Dictionary of logs.
+            logs: Training logs.
         """
         for key in self.keys:
-            if key in logs:
-                self.losses[key].append(logs[key])
+            try:
+                val = logs.__getattribute__(key)
+                self.losses[key].append(val)
+            except AttributeError:
+                pass

     def on_train_begin(self):
         """Called at the beginning of training."""
@@ -125,15 +124,14 @@ def __init__(self, trial):
         self.trial = trial
         super().__init__()

-    def __call__(self, epoch: int, logs: Dict[str, float]):
+    def __call__(self, logs: Logs):
         """Callback function.

         Called at the end of each epoch.

         Args:
-            epoch: Current epoch.
-            logs: Dictionary of logs.
+            logs: Training logs.
         """
-        self.trial.report(logs["loss/train"], step=epoch)
+        self.trial.report(logs.loss_train, step=logs.epoch)

     def on_train_begin(self):
         """Called at the beginning of training."""
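For illustration, a custom callback against the new `Logs`-based interface might look like the sketch below. `LossHistory` is a hypothetical class, not part of this commit; the `on_train_begin`/`on_train_end` hooks are assumed from their use elsewhere in this diff.

```python
from continuity.trainer.callbacks import Callback
from continuity.trainer.logs import Logs


class LossHistory(Callback):
    """Hypothetical callback that records the training loss per epoch."""

    def __init__(self):
        self.history = []
        super().__init__()

    def __call__(self, logs: Logs):
        # Called at the end of each epoch with the structured logs.
        self.history.append(logs.loss_train)

    def on_train_begin(self):
        # Reset between training runs (hook assumed from LearningCurve).
        self.history = []

    def on_train_end(self):
        pass
```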
48 changes: 48 additions & 0 deletions src/continuity/trainer/criterion.py
@@ -0,0 +1,48 @@
"""
`continuity.trainer.criterion`
Stopping criterion for Trainer in Continuity.
"""

from abc import ABC, abstractmethod
from .logs import Logs


class Criterion(ABC):
"""
Stopping criterion base class for `fit` method of `Trainer`.
"""

@abstractmethod
def __call__(self, logs: Logs):
"""Evaluate stopping criterion.
Called at the end of each epoch.
Args:
logs: Training logs.
Returns:
bool: Whether to stop training.
"""
raise NotImplementedError


class TrainingLossCriterion(Criterion):
"""
Stopping criterion based on training loss.
"""

def __init__(self, threshold: float):
self.threshold = threshold

def __call__(self, logs: Logs):
"""Callback function.
Called at the end of each epoch.
Args:
logs: Training logs.
Returns:
bool: True if training loss is below threshold.
"""
return logs.loss_train < self.threshold
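Since `Criterion` only requires `__call__(logs) -> bool`, other stopping rules drop in easily. The patience-based criterion below is a hypothetical sketch, not part of this commit:

```python
from continuity.trainer.criterion import Criterion
from continuity.trainer.logs import Logs


class PatienceCriterion(Criterion):
    """Hypothetical criterion: stop when the training loss has not
    improved for `patience` consecutive epochs."""

    def __init__(self, patience: int = 10):
        self.patience = patience
        self.best = float("inf")
        self.stale = 0

    def __call__(self, logs: Logs) -> bool:
        if logs.loss_train < self.best:
            # New best loss: reset the stagnation counter.
            self.best = logs.loss_train
            self.stale = 0
        else:
            self.stale += 1
        return self.stale >= self.patience
```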
21 changes: 21 additions & 0 deletions src/continuity/trainer/logs.py
@@ -0,0 +1,21 @@
"""
`continuity.trainer.logs`
"""

from dataclasses import dataclass


@dataclass
class Logs:
"""
Logs for callbacks and criteria within Trainer in Continuity.
Attributes:
epoch: Current epoch.
loss_train: Training loss.
time: Time taken for epoch.
"""

epoch: int
loss_train: float
seconds_per_epoch: float
38 changes: 27 additions & 11 deletions src/continuity/trainer/trainer.py
@@ -12,8 +12,10 @@
 from continuity.data import OperatorDataset
 from continuity.operators import Operator
 from continuity.operators.losses import Loss, MSELoss
-from continuity.trainer.callbacks import Callback, PrintTrainingLoss
 from continuity.trainer.device import get_device
+from .callbacks import Callback, PrintTrainingLoss
+from .criterion import Criterion, TrainingLossCriterion
+from .logs import Logs


 class Trainer:
@@ -27,12 +29,12 @@ class Trainer:
     optimizer = torch.optim.Adam(operator.parameters(), lr=1e-3)
     loss_fn = MSELoss()
     trainer = Trainer(operator, optimizer, loss_fn, device="cuda:0")
-    trainer.fit(data_loader, epochs=100)
+    trainer.fit(dataset, tol=1e-3, epochs=1000)
     ```

     Args:
         operator: Operator to be trained.
-        optimizer: Torch-like optimizer. Default is Adam.
+        optimizer: Torch-like optimizer. Default is Adam with learning rate 1e-3.
         loss_fn: Loss function taking (op, x, u, y, v). Default is MSELoss.
         device: Device to train on. Default is CPU.
         verbose: Print model parameters and use PrintTrainingLoss callback by default. Default is True.
@@ -62,17 +64,21 @@ def __init__(
     def fit(
         self,
         dataset: OperatorDataset,
-        epochs: int = 100,
+        tol: float = 1e-5,
+        epochs: int = 1000,
         callbacks: Optional[List[Callback]] = None,
+        criterion: Optional[Criterion] = None,
         batch_size: int = 32,
         shuffle: bool = True,
     ):
         """Fit operator to data set.

         Args:
             dataset: Data set.
-            epochs: Number of epochs.
-            callbacks: List of callbacks.
+            tol: Tolerance for stopping criterion. Ignored if criterion is not None.
+            epochs: Maximum number of epochs.
+            callbacks: List of callbacks. Defaults to [PrintTrainingLoss] if verbose.
+            criterion: Stopping criterion. Defaults to TrainingLossCriterion(tol).
             batch_size: Batch size.
             shuffle: Shuffle data set.
         """
@@ -83,6 +89,10 @@ def fit(
         else:
             callbacks = []

+        # Default criterion
+        if criterion is None:
+            criterion = TrainingLossCriterion(tol)
+
         # Print number of model parameters
         if self.verbose:
             num_params = sum(p.numel() for p in self.operator.parameters())
@@ -142,13 +152,19 @@ def closure(x=x, u=u, y=y, v=v):
             loss_train /= len(data_loader)

             # Callbacks
-            logs = {
-                "loss/train": loss_train,
-                "seconds_per_epoch": seconds_per_epoch,
-            }
+            logs = Logs(
+                epoch=epoch + 1,
+                loss_train=loss_train,
+                seconds_per_epoch=seconds_per_epoch,
+            )

             for callback in callbacks:
-                callback(epoch + 1, logs)
+                callback(logs)
+
+            # Stopping criterion
+            if criterion is not None:
+                if criterion(logs):
+                    break

         # Call on_train_end
         for callback in callbacks:
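Putting the pieces together, the updated `fit` can be driven either by the `tol` shortcut or by an explicit criterion. A hedged sketch, with `operator` and `dataset` as placeholders and the `Trainer` import path assumed:

```python
from continuity.trainer import Trainer  # import path assumed
from continuity.trainer.criterion import TrainingLossCriterion

# `operator` and `dataset` are placeholders for an existing Operator
# and OperatorDataset.
trainer = Trainer(operator)  # optimizer defaults to Adam with lr=1e-3

# Default path: train until loss_train < tol or `epochs` is reached.
trainer.fit(dataset, tol=1e-5, epochs=1000)

# Explicit path: a criterion overrides `tol` entirely.
trainer.fit(dataset, criterion=TrainingLossCriterion(1e-3))
```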