From 2c9389120ea40e25599663b561f74228ddfb1e09 Mon Sep 17 00:00:00 2001
From: "James R. Maddison"
Date: Wed, 14 Aug 2024 11:50:15 +0100
Subject: [PATCH] checkpoint_schedules integration
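
Add optional integration with the checkpoint_schedules library. A new
module, tlm_adjoint/checkpoint_schedules/checkpoint_schedules.py, wraps
each checkpoint_schedules.CheckpointSchedule subclass as a tlm_adjoint
CheckpointSchedule, translating checkpoint_schedules actions (Forward,
Reverse, Copy, Move, EndForward, EndReverse) into their tlm_adjoint
equivalents. The binomial, mixed, and validity checkpointing tests are
parametrized to run against both the native schedules and the wrapped
checkpoint_schedules versions, and the documentation now lists
checkpoint_schedules as an optional dependency. Wrapped schedules are
configured as usual, e.g. (from the new module docstring):

    from tlm_adjoint import configure_checkpointing
    from tlm_adjoint.checkpoint_schedules.checkpoint_schedules \
        import MultistageCheckpointSchedule

    configure_checkpointing(
        MultistageCheckpointSchedule,
        {"max_n": 30, "snapshots_in_ram": 0, "snapshots_on_disk": 3})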
---
 docs/source/conf.py                         |   3 +-
 docs/source/dependencies.rst                |   1 +
 tests/checkpoint_schedules/test_binomial.py |  35 ++-
 tests/checkpoint_schedules/test_mixed.py    |  34 ++-
 tests/checkpoint_schedules/test_validity.py |  96 +++++++-
 .../checkpoint_schedules.py                 | 207 ++++++++++++++++++
 6 files changed, 358 insertions(+), 18 deletions(-)
 create mode 100644 tlm_adjoint/checkpoint_schedules/checkpoint_schedules.py

diff --git a/docs/source/conf.py b/docs/source/conf.py
index d934de2f7..0d92cddc2 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -35,7 +35,8 @@
 
 html_css_files = ["custom.css"]
 
-intersphinx_mapping = {"firedrake": ("https://www.firedrakeproject.org", None),
+intersphinx_mapping = {"checkpoint_schedules": ("https://www.firedrakeproject.org/checkpoint_schedules", None),  # noqa: E501
+                       "firedrake": ("https://www.firedrakeproject.org", None),
                        "h5py": ("https://docs.h5py.org/en/stable", None),
                        "numpy": ("https://numpy.org/doc/stable", None),
                        "petsc4py": ("https://petsc.org/main/petsc4py", None),
diff --git a/docs/source/dependencies.rst b/docs/source/dependencies.rst
index 94515eeba..0b91e0415 100644
--- a/docs/source/dependencies.rst
+++ b/docs/source/dependencies.rst
@@ -44,3 +44,4 @@ While not required, if available some features of tlm_adjoint use:
 
   - `more-itertools <https://pypi.org/project/more-itertools>`_
   - `Numba <https://numba.pydata.org>`_
+  - `checkpoint_schedules <https://www.firedrakeproject.org/checkpoint_schedules>`_
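
checkpoint_schedules remains optional: every module touched by this patch
guards the import and degrades gracefully when the library is absent. A
minimal sketch of the convention (illustrative, not itself part of the
patch; the helper name is hypothetical, the tests inline this check):

.. code-block:: python

    import pytest

    try:
        import checkpoint_schedules
    except ModuleNotFoundError:
        checkpoint_schedules = None  # sentinel: library not installed


    def require_checkpoint_schedules():
        # Hypothetical helper: skip a test that needs the library
        if checkpoint_schedules is None:
            pytest.skip("checkpoint_schedules not available")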
diff --git a/tests/checkpoint_schedules/test_binomial.py b/tests/checkpoint_schedules/test_binomial.py
index 7c292a537..72fb22fb2 100644
--- a/tests/checkpoint_schedules/test_binomial.py
+++ b/tests/checkpoint_schedules/test_binomial.py
@@ -6,6 +6,10 @@
 import functools
 import pytest
 
+try:
+    import checkpoint_schedules
+except ModuleNotFoundError:
+    checkpoint_schedules = None
 try:
     import mpi4py.MPI as MPI
 except ModuleNotFoundError:
     MPI = None
@@ -16,7 +20,22 @@
     reason="tests must be run in serial")
 
 
+def checkpoint_schedules_multistage(
+        max_n, snapshots_in_ram, snapshots_on_disk, *,
+        trajectory="maximum"):
+    if checkpoint_schedules is None:
+        pytest.skip("checkpoint_schedules not available")
+
+    from tlm_adjoint.checkpoint_schedules.checkpoint_schedules \
+        import MultistageCheckpointSchedule
+    return MultistageCheckpointSchedule(
+        max_n, snapshots_in_ram, snapshots_on_disk,
+        trajectory=trajectory)
+
+
 @pytest.mark.checkpoint_schedules
+@pytest.mark.parametrize("schedule", [MultistageCheckpointSchedule,
+                                      checkpoint_schedules_multistage])
 @pytest.mark.parametrize("trajectory", ["revolve",
                                         "maximum"])
 @pytest.mark.parametrize("n, S", [(1, (0,)),
                                   (2, (1,)),
                                   (3, (1, 2)),
@@ -25,7 +44,8 @@
                                   (10, tuple(range(1, 10))),
                                   (100, tuple(range(1, 100))),
                                   (250, tuple(range(25, 250, 25)))])
-def test_MultistageCheckpointSchedule(trajectory,
+def test_MultistageCheckpointSchedule(schedule,
+                                      trajectory,
                                       n, S):
     @functools.singledispatch
     def action(cp_action):
@@ -71,6 +91,9 @@ def action_forward(cp_action):
         # No data for this step is stored
         assert len(data.intersection(range(cp_action.n0, cp_action.n1))) == 0  # noqa: E501
 
+        # The forward is able to advance over these steps
+        assert replay is None or replay.issuperset(range(cp_action.n0, cp_action.n1))  # noqa: E501
+
         model_n = cp_action.n1
         model_steps += cp_action.n1 - cp_action.n0
         if store_ics:
@@ -117,6 +140,8 @@ def action_read(cp_action):
 
             ics.clear()
             ics.update(cp[0])
+            replay.clear()
+            replay.update(cp[0])
             model_n = cp_action.n
 
             # Can advance the forward to the current location of the adjoint
@@ -143,9 +168,13 @@ def action_write(cp_action):
 
     @action.register(EndForward)
     def action_end_forward(cp_action):
+        nonlocal replay
+
         # The correct number of forward steps has been taken
         assert model_n == n
 
+        replay = set()
+
     @action.register(EndReverse)
     def action_end_reverse(cp_action):
         # The correct number of adjoint steps has been taken
@@ -164,11 +193,11 @@ def action_end_reverse(cp_action):
     ics = set()
     store_data = False
     data = set()
+    replay = None
 
     snapshots = {}
 
-    cp_schedule = MultistageCheckpointSchedule(n, 0, s,
-                                               trajectory=trajectory)
+    cp_schedule = schedule(n, 0, s, trajectory=trajectory)
     assert n == 1 or cp_schedule.uses_disk_storage
     assert cp_schedule.n == 0
     assert cp_schedule.r == 0
diff --git a/tests/checkpoint_schedules/test_mixed.py b/tests/checkpoint_schedules/test_mixed.py
index cb20a6f40..8d88c3558 100644
--- a/tests/checkpoint_schedules/test_mixed.py
+++ b/tests/checkpoint_schedules/test_mixed.py
@@ -7,6 +7,10 @@
 import functools
 import pytest
 
+try:
+    import checkpoint_schedules
+except ModuleNotFoundError:
+    checkpoint_schedules = None
 try:
     import mpi4py.MPI as MPI
 except ModuleNotFoundError:
     MPI = None
@@ -17,14 +21,28 @@
     reason="tests must be run in serial")
 
 
+def checkpoint_schedules_mixed(max_n, snapshots, *, storage="disk"):
+    if checkpoint_schedules is None:
+        pytest.skip("checkpoint_schedules not available")
+
+    from tlm_adjoint.checkpoint_schedules.checkpoint_schedules \
+        import MixedCheckpointSchedule, StorageType
+    storage = {"RAM": StorageType.RAM,
+               "disk": StorageType.DISK}[storage]
+    return MixedCheckpointSchedule(max_n, snapshots, storage=storage)
+
+
 @pytest.mark.checkpoint_schedules
+@pytest.mark.parametrize("schedule", [MixedCheckpointSchedule,
+                                      checkpoint_schedules_mixed])
 @pytest.mark.parametrize("n, S", [(1, (0,)),
                                   (2, (1,)),
                                   (3, (1, 2)),
                                   (10, tuple(range(1, 10))),
                                   (100, tuple(range(1, 100))),
                                   (250, tuple(range(25, 250, 25)))])
-def test_MixedCheckpointSchedule(n, S):
+def test_MixedCheckpointSchedule(schedule,
+                                 n, S):
     @functools.singledispatch
     def action(cp_action):
         raise TypeError("Unexpected action")
@@ -69,6 +87,9 @@ def action_forward(cp_action):
         # No data for this step is stored
         assert len(data.intersection(range(cp_action.n0, cp_action.n1))) == 0  # noqa: E501
 
+        # The forward is able to advance over these steps
+        assert replay is None or replay.issuperset(range(cp_action.n0, cp_action.n1))  # noqa: E501
+
         model_n = cp_action.n1
         model_steps += cp_action.n1 - cp_action.n0
         if store_ics:
@@ -120,10 +141,14 @@ def action_read(cp_action):
 
             ics.clear()
             ics.update(cp[0])
+            replay.clear()
+            replay.update(cp[0])
             model_n = cp_action.n
 
             # Can advance the forward to the current location of the adjoint
             assert ics.issuperset(range(model_n, n - model_r))
+        else:
+            replay.clear()
 
         if len(cp[1]) > 0:
             # Loading a non-linear dependency data checkpoint:
@@ -167,9 +192,13 @@ def action_write(cp_action):
 
     @action.register(EndForward)
     def action_end_forward(cp_action):
+        nonlocal replay
+
         # The correct number of forward steps has been taken
         assert model_n is not None and model_n == n
 
+        replay = set()
+
     @action.register(EndReverse)
     def action_end_reverse(cp_action):
         # The correct number of adjoint steps has been taken
@@ -188,10 +217,11 @@ def action_end_reverse(cp_action):
     ics = set()
     store_data = False
     data = set()
+    replay = None
 
     snapshots = {}
 
-    cp_schedule = MixedCheckpointSchedule(n, s, storage="disk")
+    cp_schedule = schedule(n, s, storage="disk")
     assert n == 1 or cp_schedule.uses_disk_storage
     assert cp_schedule.n == 0
     assert cp_schedule.r == 0
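
In both test modules above the schedule parameter is either the native
tlm_adjoint class or a factory returning the wrapped checkpoint_schedules
equivalent, so a single test body exercises both implementations through
the same interface. For illustration (a sketch, not part of the patch,
assuming checkpoint_schedules is installed and names in scope as in
tests/checkpoint_schedules/test_mixed.py):

.. code-block:: python

    # The native class and the wrapping factory accept the same arguments
    # and return objects satisfying the same tlm_adjoint schedule interface
    for schedule in [MixedCheckpointSchedule, checkpoint_schedules_mixed]:
        cp_schedule = schedule(10, 3, storage="disk")
        assert cp_schedule.uses_disk_storage
        assert cp_schedule.n == 0
        assert cp_schedule.r == 0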
diff --git a/tests/checkpoint_schedules/test_validity.py b/tests/checkpoint_schedules/test_validity.py
index 25b375f89..b4341df1a 100644
--- a/tests/checkpoint_schedules/test_validity.py
+++ b/tests/checkpoint_schedules/test_validity.py
@@ -11,6 +11,10 @@
 import functools
 import pytest
 
+try:
+    import checkpoint_schedules
+except ModuleNotFoundError:
+    checkpoint_schedules = None
 try:
     import hrevolve
 except ModuleNotFoundError:
     hrevolve = None
@@ -46,12 +50,13 @@ def two_level(n, s, *, period):
 
 
 def h_revolve(n, s):
+    if hrevolve is None:
+        pytest.skip("H-Revolve not available")
     if s <= 1:
-        return (None,
-                {"RAM": 0, "disk": 0}, 0)
-    else:
-        return (HRevolveCheckpointSchedule(n, s // 2, s - (s // 2)),
-                {"RAM": s // 2, "disk": s - (s // 2)}, 1)
+        pytest.skip("Incompatible with schedule type")
+
+    return (HRevolveCheckpointSchedule(n, s // 2, s - (s // 2)),
+            {"RAM": s // 2, "disk": s - (s // 2)}, 1)
 
 
 def mixed(n, s):
@@ -59,6 +64,58 @@ def mixed(n, s):
             {"RAM": 0, "disk": s}, 1)
 
 
+def checkpoint_schedules_memory(n, s):
+    if checkpoint_schedules is None:
+        pytest.skip("checkpoint_schedules not available")
+
+    from tlm_adjoint.checkpoint_schedules.checkpoint_schedules import \
+        SingleMemoryStorageSchedule
+    return (SingleMemoryStorageSchedule(),
+            {"RAM": 0, "disk": 0}, 1 + n)
+
+
+def checkpoint_schedules_multistage(n, s):
+    if checkpoint_schedules is None:
+        pytest.skip("checkpoint_schedules not available")
+
+    from tlm_adjoint.checkpoint_schedules.checkpoint_schedules import \
+        MultistageCheckpointSchedule
+    return (MultistageCheckpointSchedule(n, 0, s),
+            {"RAM": 0, "disk": s}, 1)
+
+
+def checkpoint_schedules_two_level(n, s, *, period):
+    if checkpoint_schedules is None:
+        pytest.skip("checkpoint_schedules not available")
+
+    from tlm_adjoint.checkpoint_schedules.checkpoint_schedules import \
+        StorageType, TwoLevelCheckpointSchedule
+    return (TwoLevelCheckpointSchedule(period, s, binomial_storage=StorageType.RAM),  # noqa: E501
+            {"RAM": s, "disk": 1 + (n - 1) // period}, 1)
+
+
+def checkpoint_schedules_h_revolve(n, s):
+    if checkpoint_schedules is None:
+        pytest.skip("checkpoint_schedules not available")
+    if s <= 1:
+        pytest.skip("Incompatible with schedule type")
+
+    from tlm_adjoint.checkpoint_schedules.checkpoint_schedules import \
+        HRevolve
+    return (HRevolve(n, s // 2, s - (s // 2)),
+            {"RAM": s // 2, "disk": s - (s // 2)}, 1)
+
+
+def checkpoint_schedules_mixed(n, s):
+    if checkpoint_schedules is None:
+        pytest.skip("checkpoint_schedules not available")
+
+    from tlm_adjoint.checkpoint_schedules.checkpoint_schedules import \
+        MixedCheckpointSchedule
+    return (MixedCheckpointSchedule(n, s),
+            {"RAM": 0, "disk": s}, 1)
+
+
 @pytest.mark.checkpoint_schedules
 @pytest.mark.parametrize(
     "schedule, schedule_kwargs",
@@ -72,11 +129,16 @@
         (two_level, {"period": 2}),
         (two_level, {"period": 7}),
         (two_level, {"period": 10}),
-        pytest.param(
-            h_revolve, {},
-            marks=pytest.mark.skipif(hrevolve is None,
-                                     reason="H-Revolve not available")),
-        (mixed, {})])
+        (h_revolve, {}),
+        (mixed, {}),
+        (checkpoint_schedules_memory, {}),
+        (checkpoint_schedules_multistage, {}),
+        (checkpoint_schedules_two_level, {"period": 1}),
+        (checkpoint_schedules_two_level, {"period": 2}),
+        (checkpoint_schedules_two_level, {"period": 7}),
+        (checkpoint_schedules_two_level, {"period": 10}),
+        (checkpoint_schedules_h_revolve, {}),
+        (checkpoint_schedules_mixed, {})])
 @pytest.mark.parametrize("n, S", [(1, (0,)),
                                   (2, (1,)),
                                   (3, (1, 2)),
@@ -126,6 +188,9 @@ def action_forward(cp_action):
         # No non-linear dependency data for these steps is stored
         assert len(data.intersection(range(cp_action.n0, n1))) == 0
 
+        # The forward is able to advance over these steps
+        assert replay is None or replay.issuperset(range(cp_action.n0, n1))
+
         model_n = n1
         if store_ics:
             ics.update(range(cp_action.n0, n1))
@@ -170,10 +235,14 @@ def action_read(cp_action):
         if len(cp[0]) > 0:
             ics.clear()
             ics.update(cp[0])
+            replay.clear()
+            replay.update(cp[0])
             model_n = cp_action.n
 
             # Can advance the forward to the current location of the adjoint
             assert ics.issuperset(range(model_n, n - model_r))
+        else:
+            replay.clear()
 
         if len(cp[1]) > 0:
             data.clear()
@@ -202,9 +271,13 @@ def action_write(cp_action):
 
     @action.register(EndForward)
     def action_end_forward(cp_action):
+        nonlocal replay
+
         # The correct number of forward steps has been taken
         assert model_n is not None and model_n == n
 
+        replay = set()
+
     @action.register(EndReverse)
     def action_end_reverse(cp_action):
         nonlocal model_r
@@ -225,12 +298,11 @@ def action_end_reverse(cp_action):
     ics = set()
     store_data = False
     data = set()
+    replay = None
 
     snapshots = {"RAM": {}, "disk": {}}
 
     cp_schedule, storage_limits, data_limit = schedule(n, s, **schedule_kwargs)  # noqa: E501
-    if cp_schedule is None:
-        pytest.skip("Incompatible with schedule type")
     assert cp_schedule.n == 0
     assert cp_schedule.r == 0
     assert cp_schedule.max_n is None or cp_schedule.max_n == n
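
Each schedule function in test_validity.py returns a triple: the schedule,
a dict bounding the number of snapshots permitted in each storage kind, and
a bound on the number of simultaneously stored non-linear dependency data
sets. Configurations a schedule cannot handle now skip via pytest.skip
rather than returning None. For illustration (a sketch, not part of the
patch, assuming checkpoint_schedules is installed):

.. code-block:: python

    # From the checkpoint_schedules_mixed definition above, which returns
    #     (MixedCheckpointSchedule(n, s), {"RAM": 0, "disk": s}, 1)
    cp_schedule, storage_limits, data_limit = checkpoint_schedules_mixed(10, 3)
    assert storage_limits == {"RAM": 0, "disk": 3}
    assert data_limit == 1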
diff --git a/tlm_adjoint/checkpoint_schedules/checkpoint_schedules.py b/tlm_adjoint/checkpoint_schedules/checkpoint_schedules.py
new file mode 100644
index 000000000..a6e429baa
--- /dev/null
+++ b/tlm_adjoint/checkpoint_schedules/checkpoint_schedules.py
@@ -0,0 +1,207 @@
+"""Translation between checkpointing schedules provided by the
+checkpoint_schedules library and a tlm_adjoint :class:`.CheckpointSchedule`.
+
+Wrapped :class:`checkpoint_schedules.CheckpointSchedule` classes can be
+imported from this module and then passed to :func:`.configure_checkpointing`,
+e.g.
+
+.. code-block:: python
+
+    from tlm_adjoint import configure_checkpointing
+    from tlm_adjoint.checkpoint_schedules.checkpoint_schedules \
+        import MultistageCheckpointSchedule
+
+    configure_checkpointing(
+        MultistageCheckpointSchedule,
+        {"max_n": 30, "snapshots_in_ram": 0, "snapshots_on_disk": 3})
+"""
+
+try:
+    import checkpoint_schedules
+except ModuleNotFoundError:
+    checkpoint_schedules = None
+if checkpoint_schedules is not None:
+    from checkpoint_schedules import (
+        CheckpointSchedule as _CheckpointSchedule, Forward as _Forward,
+        Reverse as _Reverse, Copy as _Copy, Move as _Move,
+        EndForward as _EndForward, EndReverse as _EndReverse)
+    from checkpoint_schedules import StorageType
+
+from .schedule import (
+    CheckpointSchedule, Configure, Clear, Forward, Reverse, Read, Write,
+    EndForward, EndReverse)
+
+from functools import singledispatch, wraps
+
+
+def translation(cls):
+    class Translation(CheckpointSchedule):
+        def __init__(self, *args, **kwargs):
+            self._cp_schedule = cls(*args, **kwargs)
+            super().__init__(self._cp_schedule.max_n)
+            self._is_exhausted = self._cp_schedule.is_exhausted
+
+        def iter(self):
+            # Used to ensure that we do not finalize the wrapped schedule
+            # while yielding actions associated with a single wrapped action.
+            # Prevents multiple finalization of the wrapped schedule.
+            def locked(fn):
+                @wraps(fn)
+                def wrapped_fn(cp_action):
+                    max_n = self.max_n
+                    yield from fn(cp_action)
+                    if self.max_n != max_n:
+                        self._cp_schedule.finalize(self.max_n)
+                return wrapped_fn
+
+            @singledispatch
+            @locked
+            def action(cp_action):
+                raise TypeError(f"Unexpected action type: {type(cp_action)}")
+                yield None
+
+            ics = (0, 0)
+            data = (0, 0)
+            replay = None
+            checkpoints = {StorageType.RAM: {}, StorageType.DISK: {}}
+
+            def clear():
+                nonlocal ics, data
+
+                ics = (0, 0)
+                data = (0, 0)
+                yield Clear(True, True)
+
+            def read(n, storage, *, delete):
+                nonlocal ics, data, replay
+
+                replay, _ = ics, data = checkpoints[storage][n]
+                if delete:
+                    del checkpoints[storage][n]
+                self._n = n
+                yield Read(n, {StorageType.RAM: "RAM",
+                               StorageType.DISK: "disk"}[storage], delete)
+
+            def write(n, storage):
+                checkpoints[storage][n] = (ics, data)
+                yield Write(n, {StorageType.RAM: "RAM",
+                                StorageType.DISK: "disk"}[storage])
+
+            def input_output(n, from_storage, to_storage, *, delete):
+                if to_storage in {StorageType.RAM, StorageType.DISK}:
+                    yield from clear()
+                    yield from read(n, from_storage, delete=delete)
+                    yield from write(n, to_storage)
+                    yield from clear()
+                elif to_storage == StorageType.WORK:
+                    yield from clear()
+                    yield from read(n, from_storage, delete=delete)
+                else:
+                    raise ValueError(f"Unexpected storage type: "
+                                     f"{to_storage}")
+
+            @action.register(_Forward)
+            @locked
+            def action_forward(cp_action):
+                nonlocal ics, data
+
+                yield from clear()
+                yield Configure(cp_action.write_ics, cp_action.write_adj_deps)
+
+                if cp_action.write_ics:
+                    ics = (cp_action.n0, cp_action.n1)
+                if cp_action.write_adj_deps:
+                    data = (cp_action.n0, cp_action.n1)
+                if replay is not None and (cp_action.n0 < replay[0] or cp_action.n1 > replay[1]):  # noqa: E501
+                    raise RuntimeError("Invalid checkpointing state")
+                self._n = cp_action.n1
+                yield Forward(cp_action.n0, cp_action.n1)
+
+                if cp_action.storage == StorageType.NONE:
+                    if cp_action.write_ics or cp_action.write_adj_deps:
+                        raise ValueError("Unexpected action parameters")
+                elif cp_action.storage in {StorageType.RAM, StorageType.DISK}:
+                    yield from write(cp_action.n0, cp_action.storage)
+                    yield from clear()
+                elif cp_action.storage == StorageType.WORK:
+                    if cp_action.write_ics:
+                        raise ValueError("Unexpected action parameters")
+                else:
+                    raise ValueError(f"Unexpected storage type: "
+                                     f"{cp_action.storage}")
+
+            @action.register(_Reverse)
+            @locked
+            def action_reverse(cp_action):
+                if self.max_n is None:
+                    raise RuntimeError("Invalid checkpointing state")
+                if cp_action.n0 < data[0] or cp_action.n1 > data[1]:
+                    raise RuntimeError("Invalid checkpointing state")
+                self._r = self.max_n - cp_action.n0
+                yield Reverse(cp_action.n1, cp_action.n0)
+                yield from clear()
+
+            @action.register(_Copy)
+            @locked
+            def action_copy(cp_action):
+                yield from input_output(
+                    cp_action.n, cp_action.from_storage, cp_action.to_storage,
+                    delete=False)
+
+            @action.register(_Move)
+            @locked
+            def action_move(cp_action):
+                yield from input_output(
+                    cp_action.n, cp_action.from_storage, cp_action.to_storage,
+                    delete=True)
+
+            @action.register(_EndForward)
+            @locked
+            def action_end_forward(cp_action):
+                self._is_exhausted = self._cp_schedule.is_exhausted
+                yield EndForward()
+
+            @action.register(_EndReverse)
+            @locked
+            def action_end_reverse(cp_action):
+                if self._cp_schedule.is_exhausted:
+                    yield from clear()
+                self._r = self._cp_schedule.r
+                self._is_exhausted = self._cp_schedule.is_exhausted
+                yield EndReverse(self._cp_schedule.is_exhausted)
+
+            yield from clear()
+            while not self._cp_schedule.is_exhausted:
+                yield from action(next(self._cp_schedule))
+
+        @property
+        def is_exhausted(self):
+            return self._is_exhausted
+
+        @property
+        def uses_disk_storage(self):
+            return self._cp_schedule.uses_storage_type(StorageType.DISK)
+
+    return Translation
+
+
+def _init():
+    if checkpoint_schedules is not None:
+        __all__.append("StorageType")
+
+        for name in dir(checkpoint_schedules):
+            obj = getattr(checkpoint_schedules, name)
+            if isinstance(obj, type) \
+                    and issubclass(obj, _CheckpointSchedule) \
+                    and obj is not _CheckpointSchedule:
+                globals()[name] = cls = translation(obj)
+                cls.__doc__ = (f"Wrapper for the checkpoint_schedules "
+                               f":class:`checkpoint_schedules.{name}` class.")
+                __all__.append(name)
+
+    __all__.sort()
+
+
+__all__ = []
+_init()
+del _init
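
The _init above generates, at import time, one wrapper per
checkpoint_schedules.CheckpointSchedule subclass, re-exporting it from this
module under the same name alongside StorageType. Beyond
configure_checkpointing (see the module docstring), a wrapped schedule can
also be constructed directly; a sketch, not part of the patch, assuming
checkpoint_schedules is installed:

.. code-block:: python

    from tlm_adjoint.checkpoint_schedules.checkpoint_schedules import \
        MixedCheckpointSchedule, StorageType

    # Constructor arguments are those of the wrapped
    # checkpoint_schedules.MixedCheckpointSchedule
    cp_schedule = MixedCheckpointSchedule(30, 3, storage=StorageType.DISK)

    # The wrapper presents the tlm_adjoint CheckpointSchedule interface
    assert cp_schedule.uses_disk_storage
    assert cp_schedule.n == 0 and cp_schedule.r == 0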