checkpoint_schedules integration
jrmaddison committed Aug 14, 2024
1 parent 7441732 commit 2c93891
Showing 6 changed files with 358 additions and 18 deletions.
3 changes: 2 additions & 1 deletion docs/source/conf.py
@@ -35,7 +35,8 @@

html_css_files = ["custom.css"]

intersphinx_mapping = {"firedrake": ("https://www.firedrakeproject.org", None),
intersphinx_mapping = {"checkpoint_schedules": ("https://www.firedrakeproject.org/checkpoint_schedules", None), # noqa: E501
"firedrake": ("https://www.firedrakeproject.org", None),
"h5py": ("https://docs.h5py.org/en/stable", None),
"numpy": ("https://numpy.org/doc/stable", None),
"petsc4py": ("https://petsc.org/main/petsc4py", None),
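For reference, the resulting mapping after this change is sketched below. Only the entries visible in the hunk above are reproduced; the remaining entries fall outside the hunk and are assumed unchanged.

# Sketch of the updated intersphinx_mapping in docs/source/conf.py.
# Entries below the petsc4py line are truncated in the diff and omitted here.
intersphinx_mapping = {
    "checkpoint_schedules": ("https://www.firedrakeproject.org/checkpoint_schedules", None),  # noqa: E501
    "firedrake": ("https://www.firedrakeproject.org", None),
    "h5py": ("https://docs.h5py.org/en/stable", None),
    "numpy": ("https://numpy.org/doc/stable", None),
    "petsc4py": ("https://petsc.org/main/petsc4py", None),
    # ... further entries not shown in the hunk
}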
1 change: 1 addition & 0 deletions docs/source/dependencies.rst
@@ -44,3 +44,4 @@ While not required, if available some features of tlm_adjoint use:

- `more-itertools <https://more-itertools.readthedocs.io>`_
- `Numba <https://numba.pydata.org>`_
- `checkpoint_schedules <https://www.firedrakeproject.org/checkpoint_schedules>`_
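Because checkpoint_schedules is optional, code that uses it has to cope with the package being absent. A minimal sketch of the guard used by the updated tests below; nothing beyond the import pattern shown in the diff is assumed.

# Guarded optional import, mirroring the pattern added to the test modules.
try:
    import checkpoint_schedules
except ModuleNotFoundError:
    checkpoint_schedules = None

if checkpoint_schedules is None:
    # Features relying on checkpoint_schedules are skipped or disabled.
    print("checkpoint_schedules not available")
else:
    print("checkpoint_schedules available")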
35 changes: 32 additions & 3 deletions tests/checkpoint_schedules/test_binomial.py
@@ -6,6 +6,10 @@
import functools
import pytest

try:
import checkpoint_schedules
except ModuleNotFoundError:
checkpoint_schedules = None
try:
import mpi4py.MPI as MPI
except ModuleNotFoundError:
@@ -16,7 +20,22 @@
reason="tests must be run in serial")


def checkpoint_schedules_multistage(
max_n, snapshots_in_ram, snapshots_on_disk, *,
trajectory="maximum"):
if checkpoint_schedules is None:
pytest.skip("checkpoint_schedules not available")

from tlm_adjoint.checkpoint_schedules.checkpoint_schedules \
import MultistageCheckpointSchedule
return MultistageCheckpointSchedule(
max_n, snapshots_in_ram, snapshots_on_disk,
trajectory=trajectory)


@pytest.mark.checkpoint_schedules
@pytest.mark.parametrize("schedule", [MultistageCheckpointSchedule,
checkpoint_schedules_multistage])
@pytest.mark.parametrize("trajectory", ["revolve",
"maximum"])
@pytest.mark.parametrize("n, S", [(1, (0,)),
@@ -25,7 +44,8 @@
(10, tuple(range(1, 10))),
(100, tuple(range(1, 100))),
(250, tuple(range(25, 250, 25)))])
-def test_MultistageCheckpointSchedule(trajectory,
+def test_MultistageCheckpointSchedule(schedule,
+                                      trajectory,
n, S):
@functools.singledispatch
def action(cp_action):
@@ -71,6 +91,9 @@ def action_forward(cp_action):
# No data for this step is stored
assert len(data.intersection(range(cp_action.n0, cp_action.n1))) == 0 # noqa: E501

# The forward is able to advance over these steps
assert replay is None or replay.issuperset(range(cp_action.n0, cp_action.n1)) # noqa: E501

model_n = cp_action.n1
model_steps += cp_action.n1 - cp_action.n0
if store_ics:
@@ -117,6 +140,8 @@ def action_read(cp_action):

ics.clear()
ics.update(cp[0])
replay.clear()
replay.update(cp[0])
model_n = cp_action.n

# Can advance the forward to the current location of the adjoint
@@ -143,9 +168,13 @@ def action_write(cp_action):

@action.register(EndForward)
def action_end_forward(cp_action):
nonlocal replay

# The correct number of forward steps has been taken
assert model_n == n

replay = set()

@action.register(EndReverse)
def action_end_reverse(cp_action):
# The correct number of adjoint steps has been taken
@@ -164,11 +193,11 @@ def action_end_reverse(cp_action):
ics = set()
store_data = False
data = set()
replay = None

snapshots = {}

-    cp_schedule = MultistageCheckpointSchedule(n, 0, s,
-                                               trajectory=trajectory)
+    cp_schedule = schedule(n, 0, s, trajectory=trajectory)
assert n == 1 or cp_schedule.uses_disk_storage
assert cp_schedule.n == 0
assert cp_schedule.r == 0
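The test body is now shared between the two implementations: the new schedule parameter is either the native MultistageCheckpointSchedule class or the checkpoint_schedules_multistage factory defined above, and both are called as schedule(n, 0, s, trajectory=trajectory). The test also gains a replay set recording, after each checkpoint read, the steps over which the forward can be re-advanced, and the Forward action now asserts that it only advances over such steps. A minimal sketch of the two construction paths follows; the native import path is an assumption (it lies outside the visible hunks), while the adaptor path and the asserted interface are taken from the diff.

# Both construction paths exercised by the parametrized test.
# Native tlm_adjoint schedule (assumed import path, outside the visible hunks).
from tlm_adjoint.checkpoint_schedules import MultistageCheckpointSchedule
# checkpoint_schedules-backed schedule, via the adaptor module used above.
from tlm_adjoint.checkpoint_schedules.checkpoint_schedules import \
    MultistageCheckpointSchedule as CSMultistageCheckpointSchedule

n, s = 10, 3
cp_native = MultistageCheckpointSchedule(n, 0, s, trajectory="maximum")
cp_wrapped = CSMultistageCheckpointSchedule(n, 0, s, trajectory="maximum")

# Both expose the schedule interface the test relies on.
for cp_schedule in (cp_native, cp_wrapped):
    assert cp_schedule.n == 0
    assert cp_schedule.r == 0
    assert cp_schedule.uses_disk_storage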
34 changes: 32 additions & 2 deletions tests/checkpoint_schedules/test_mixed.py
@@ -7,6 +7,10 @@
import functools
import pytest

try:
import checkpoint_schedules
except ModuleNotFoundError:
checkpoint_schedules = None
try:
import mpi4py.MPI as MPI
except ModuleNotFoundError:
@@ -17,14 +21,28 @@
reason="tests must be run in serial")


def checkpoint_schedules_mixed(max_n, snapshots, *, storage="disk"):
if checkpoint_schedules is None:
pytest.skip("checkpoint_schedules not available")

from tlm_adjoint.checkpoint_schedules.checkpoint_schedules \
import MixedCheckpointSchedule, StorageType
storage = {"RAM": StorageType.RAM,
"disk": StorageType.DISK}[storage]
return MixedCheckpointSchedule(max_n, snapshots, storage=storage)


@pytest.mark.checkpoint_schedules
@pytest.mark.parametrize("schedule", [MixedCheckpointSchedule,
checkpoint_schedules_mixed])
@pytest.mark.parametrize("n, S", [(1, (0,)),
(2, (1,)),
(3, (1, 2)),
(10, tuple(range(1, 10))),
(100, tuple(range(1, 100))),
(250, tuple(range(25, 250, 25)))])
-def test_MixedCheckpointSchedule(n, S):
+def test_MixedCheckpointSchedule(schedule,
+                                 n, S):
@functools.singledispatch
def action(cp_action):
raise TypeError("Unexpected action")
@@ -69,6 +87,9 @@ def action_forward(cp_action):
# No data for this step is stored
assert len(data.intersection(range(cp_action.n0, cp_action.n1))) == 0 # noqa: E501

# The forward is able to advance over these steps
assert replay is None or replay.issuperset(range(cp_action.n0, cp_action.n1)) # noqa: E501

model_n = cp_action.n1
model_steps += cp_action.n1 - cp_action.n0
if store_ics:
@@ -120,10 +141,14 @@ def action_read(cp_action):

ics.clear()
ics.update(cp[0])
replay.clear()
replay.update(cp[0])
model_n = cp_action.n

# Can advance the forward to the current location of the adjoint
assert ics.issuperset(range(model_n, n - model_r))
else:
replay.clear()

if len(cp[1]) > 0:
# Loading a non-linear dependency data checkpoint:
Expand Down Expand Up @@ -167,9 +192,13 @@ def action_write(cp_action):

@action.register(EndForward)
def action_end_forward(cp_action):
nonlocal replay

# The correct number of forward steps has been taken
assert model_n is not None and model_n == n

replay = set()

@action.register(EndReverse)
def action_end_reverse(cp_action):
# The correct number of adjoint steps has been taken
@@ -188,10 +217,11 @@ def action_end_reverse(cp_action):
ics = set()
store_data = False
data = set()
replay = None

snapshots = {}

-    cp_schedule = MixedCheckpointSchedule(n, s, storage="disk")
+    cp_schedule = schedule(n, s, storage="disk")
assert n == 1 or cp_schedule.uses_disk_storage
assert cp_schedule.n == 0
assert cp_schedule.r == 0
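The checkpoint_schedules_mixed factory above translates the string storage argument used by the native test into the upstream StorageType enumeration before constructing the schedule. A minimal sketch of that construction, using only names and arguments that appear in the diff:

# Build the checkpoint_schedules-backed mixed schedule with disk storage,
# as checkpoint_schedules_mixed does for the parametrized test.
from tlm_adjoint.checkpoint_schedules.checkpoint_schedules import \
    MixedCheckpointSchedule, StorageType

max_n, snapshots = 10, 3
storage = {"RAM": StorageType.RAM, "disk": StorageType.DISK}["disk"]
cp_schedule = MixedCheckpointSchedule(max_n, snapshots, storage=storage)

# The test then checks the usual invariants on the freshly built schedule.
assert cp_schedule.n == 0
assert cp_schedule.r == 0
assert cp_schedule.uses_disk_storage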
96 changes: 84 additions & 12 deletions tests/checkpoint_schedules/test_validity.py
@@ -11,6 +11,10 @@
import functools
import pytest

try:
import checkpoint_schedules
except ModuleNotFoundError:
checkpoint_schedules = None
try:
import hrevolve
except ModuleNotFoundError:
@@ -46,19 +50,72 @@ def two_level(n, s, *, period):


def h_revolve(n, s):
+    if hrevolve is None:
+        pytest.skip("H-Revolve not available")
if s <= 1:
-        return (None,
-                {"RAM": 0, "disk": 0}, 0)
-    else:
-        return (HRevolveCheckpointSchedule(n, s // 2, s - (s // 2)),
-                {"RAM": s // 2, "disk": s - (s // 2)}, 1)
+        pytest.skip("Incompatible with schedule type")
+
+    return (HRevolveCheckpointSchedule(n, s // 2, s - (s // 2)),
+            {"RAM": s // 2, "disk": s - (s // 2)}, 1)


def mixed(n, s):
return (MixedCheckpointSchedule(n, s),
{"RAM": 0, "disk": s}, 1)


def checkpoint_schedules_memory(n, s):
if checkpoint_schedules is None:
pytest.skip("checkpoint_schedules not available")

from tlm_adjoint.checkpoint_schedules.checkpoint_schedules import \
SingleMemoryStorageSchedule
return (SingleMemoryStorageSchedule(),
{"RAM": 0, "disk": 0}, 1 + n)


def checkpoint_schedules_multistage(n, s):
if checkpoint_schedules is None:
pytest.skip("checkpoint_schedules not available")

from tlm_adjoint.checkpoint_schedules.checkpoint_schedules import \
MultistageCheckpointSchedule
return (MultistageCheckpointSchedule(n, 0, s),
{"RAM": 0, "disk": s}, 1)


def checkpoint_schedules_two_level(n, s, *, period):
if checkpoint_schedules is None:
pytest.skip("checkpoint_schedules not available")

from tlm_adjoint.checkpoint_schedules.checkpoint_schedules import \
StorageType, TwoLevelCheckpointSchedule
return (TwoLevelCheckpointSchedule(period, s, binomial_storage=StorageType.RAM), # noqa: E501
{"RAM": s, "disk": 1 + (n - 1) // period}, 1)


def checkpoint_schedules_h_revolve(n, s):
if checkpoint_schedules is None:
pytest.skip("checkpoint_schedules not available")
if s <= 1:
pytest.skip("Incompatible with schedule type")

from tlm_adjoint.checkpoint_schedules.checkpoint_schedules import \
HRevolve
return (HRevolve(n, s // 2, s - (s // 2)),
{"RAM": s // 2, "disk": s - (s // 2)}, 1)


def checkpoint_schedules_mixed(n, s):
if checkpoint_schedules is None:
pytest.skip("checkpoint_schedules not available")

from tlm_adjoint.checkpoint_schedules.checkpoint_schedules import \
MixedCheckpointSchedule
return (MixedCheckpointSchedule(n, s),
{"RAM": 0, "disk": s}, 1)


@pytest.mark.checkpoint_schedules
@pytest.mark.parametrize(
"schedule, schedule_kwargs",
@@ -72,11 +129,16 @@
(two_level, {"period": 2}),
(two_level, {"period": 7}),
(two_level, {"period": 10}),
-     pytest.param(
-         h_revolve, {},
-         marks=pytest.mark.skipif(hrevolve is None,
-                                  reason="H-Revolve not available")),
-     (mixed, {})])
+     (h_revolve, {}),
+     (mixed, {}),
+     (checkpoint_schedules_memory, {}),
+     (checkpoint_schedules_multistage, {}),
+     (checkpoint_schedules_two_level, {"period": 1}),
+     (checkpoint_schedules_two_level, {"period": 2}),
+     (checkpoint_schedules_two_level, {"period": 7}),
+     (checkpoint_schedules_two_level, {"period": 10}),
+     (checkpoint_schedules_h_revolve, {}),
+     (checkpoint_schedules_mixed, {})])
@pytest.mark.parametrize("n, S", [(1, (0,)),
(2, (1,)),
(3, (1, 2)),
@@ -126,6 +188,9 @@ def action_forward(cp_action):
# No non-linear dependency data for these steps is stored
assert len(data.intersection(range(cp_action.n0, n1))) == 0

# The forward is able to advance over these steps
assert replay is None or replay.issuperset(range(cp_action.n0, n1))

model_n = n1
if store_ics:
ics.update(range(cp_action.n0, n1))
@@ -170,10 +235,14 @@ def action_read(cp_action):
if len(cp[0]) > 0:
ics.clear()
ics.update(cp[0])
replay.clear()
replay.update(cp[0])
model_n = cp_action.n

# Can advance the forward to the current location of the adjoint
assert ics.issuperset(range(model_n, n - model_r))
else:
replay.clear()

if len(cp[1]) > 0:
data.clear()
@@ -202,9 +271,13 @@ def action_write(cp_action):

@action.register(EndForward)
def action_end_forward(cp_action):
nonlocal replay

# The correct number of forward steps has been taken
assert model_n is not None and model_n == n

replay = set()

@action.register(EndReverse)
def action_end_reverse(cp_action):
nonlocal model_r
@@ -225,12 +298,11 @@ def action_end_reverse(cp_action):
ics = set()
store_data = False
data = set()
replay = None

snapshots = {"RAM": {}, "disk": {}}

cp_schedule, storage_limits, data_limit = schedule(n, s, **schedule_kwargs) # noqa: E501
-    if cp_schedule is None:
-        pytest.skip("Incompatible with schedule type")
assert cp_schedule.n == 0
assert cp_schedule.r == 0
assert cp_schedule.max_n is None or cp_schedule.max_n == n
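Each schedule factory in test_validity.py now returns a (schedule, storage_limits, data_limit) triple, and unavailable or incompatible configurations skip inside the factory rather than returning None, which is why the old `if cp_schedule is None` guard is dropped above. A minimal sketch of the factory contract, using checkpoint_schedules_multistage as the example with its availability guard omitted; the limit checks at the end are a hypothetical illustration of how the declared bounds are meant to be used.

# Factory contract used by test_validity: a schedule plus declared limits on
# stored snapshots per storage kind and on simultaneously stored step data.
from tlm_adjoint.checkpoint_schedules.checkpoint_schedules import \
    MultistageCheckpointSchedule


def checkpoint_schedules_multistage(n, s):
    # Availability guard (pytest.skip) omitted in this sketch.
    return (MultistageCheckpointSchedule(n, 0, s),
            {"RAM": 0, "disk": s}, 1)


n, s = 10, 3
cp_schedule, storage_limits, data_limit = checkpoint_schedules_multistage(n, s)
assert cp_schedule.n == 0
assert cp_schedule.r == 0
assert cp_schedule.max_n is None or cp_schedule.max_n == n
# Hypothetical use of the declared limits: while iterating the schedule the
# test tracks stored snapshots and loaded data and checks them against these.
assert storage_limits == {"RAM": 0, "disk": s}
assert data_limit == 1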
