From a8927cacc66b3e1ca4d6cdac570e68b3df84b965 Mon Sep 17 00:00:00 2001
From: Daiane Iglesia Dolci <63597005+Ig-dolci@users.noreply.github.com>
Date: Fri, 16 Feb 2024 16:05:28 +0000
Subject: [PATCH] Support checkpointing
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add support for incremental checkpointing (e.g. revolve) in Pyadjoint.

---------

Co-authored-by: David Ham
Co-authored-by: Jørgen Schartum Dokken
---
 docs/source/documentation/pyadjoint_api.rst   |   2 +
 pyadjoint/block.py                            |   2 +-
 pyadjoint/block_variable.py                   |  24 +-
 pyadjoint/checkpointing.py                    | 334 ++++++++++++++++++
 pyadjoint/drivers.py                          |  26 +-
 pyadjoint/reduced_functional.py               |  11 +-
 pyadjoint/tape.py                             | 290 +++++++++++++--
 setup.py                                      |   2 +-
 .../firedrake_adjoint/test_burgers_newton.py  | 116 ++++--
 9 files changed, 728 insertions(+), 79 deletions(-)
 create mode 100644 pyadjoint/checkpointing.py

diff --git a/docs/source/documentation/pyadjoint_api.rst b/docs/source/documentation/pyadjoint_api.rst
index 1bfb0ee0..36fed246 100644
--- a/docs/source/documentation/pyadjoint_api.rst
+++ b/docs/source/documentation/pyadjoint_api.rst
@@ -19,6 +19,8 @@ Core classes
     .. automethod:: add_block
     .. automethod:: visualise
     .. autoproperty:: progress_bar
+    .. automethod:: end_timestep
+    .. automethod:: timestepper

 .. autoclass:: Block

diff --git a/pyadjoint/block.py b/pyadjoint/block.py
index e9929861..29dbc70b 100644
--- a/pyadjoint/block.py
+++ b/pyadjoint/block.py
@@ -50,7 +50,7 @@ def add_dependency(self, dep, no_duplicates=False):

         """
         if not no_duplicates or dep.block_variable not in self._dependencies:
-            dep._ad_will_add_as_dependency()
+            dep.block_variable.will_add_as_dependency()
             self._dependencies.append(dep.block_variable)

     def get_dependencies(self):
diff --git a/pyadjoint/block_variable.py b/pyadjoint/block_variable.py
index 86002638..3b3fa25d 100644
--- a/pyadjoint/block_variable.py
+++ b/pyadjoint/block_variable.py
@@ -1,4 +1,4 @@
-from .tape import no_annotations
+from .tape import no_annotations, get_working_tape


 class BlockVariable(object):
@@ -16,6 +16,10 @@ def __init__(self, output):
         self.floating_type = False
         # Helper flag for use during tape traversals.
         self.marked_in_path = False
+        # By default assume the variable is created externally to the tape.
+        self.creation_timestep = -1
+        # The timestep during which this variable was last used as an input.
+        self.last_use = -1

     def add_adj_output(self, val):
         if self.adj_value is None:
@@ -59,13 +63,23 @@ def saved_output(self):

     def will_add_as_dependency(self):
         overwrite = self.output._ad_will_add_as_dependency()
-        overwrite = False if overwrite is None else overwrite
-        self.save_output(overwrite=overwrite)
+        overwrite = bool(overwrite)
+        tape = get_working_tape()
+        if self.last_use < tape.latest_checkpoint:
+            self.save_output(overwrite=overwrite)
+            tape.add_to_checkpointable_state(self, self.last_use)
+        self.last_use = tape.latest_timestep

     def will_add_as_output(self):
+        tape = get_working_tape()
+        self.creation_timestep = tape.latest_timestep
+        self.last_use = self.creation_timestep
         overwrite = self.output._ad_will_add_as_output()
-        overwrite = True if overwrite is None else overwrite
-        self.save_output(overwrite=overwrite)
+        overwrite = bool(overwrite)
+        if not overwrite:
+            self._checkpoint = None
+        if tape._eagerly_checkpoint_outputs:
+            self.save_output()

     def __str__(self):
         return str(self.output)
diff --git a/pyadjoint/checkpointing.py b/pyadjoint/checkpointing.py
new file mode 100644
index 00000000..d20a0b2f
--- /dev/null
+++ b/pyadjoint/checkpointing.py
@@ -0,0 +1,334 @@
+from enum import Enum
+from functools import singledispatchmethod
+from checkpoint_schedules import (
+    Copy, Move, EndForward, EndReverse, Forward, Reverse, StorageType)
+from checkpoint_schedules import Revolve, MultistageCheckpointSchedule
+
+
+class CheckpointError(RuntimeError):
+    pass
+
+
+class Mode(Enum):
+    """The mode of the checkpoint manager.
+
+    RECORD: The forward model is being taped.
+    FINISHED_RECORDING: The taping of the forward model has finished.
+    EVALUATED: The forward model has been evaluated.
+    EXHAUSTED: The forward and adjoint models have been evaluated and the schedule has concluded.
+    RECOMPUTE: The forward model is being recomputed.
+    EVALUATE_ADJOINT: The adjoint model is being evaluated.
+
+    """
+    RECORD = 1
+    FINISHED_RECORDING = 2
+    EVALUATED = 3
+    EXHAUSTED = 4
+    RECOMPUTE = 5
+    EVALUATE_ADJOINT = 6
+
+
+class CheckpointManager:
+    """Manage the executions of the forward and adjoint solvers.
+
+    Args:
+        schedule (checkpoint_schedules.schedule): A schedule provided by the `checkpoint_schedules` package.
+        tape (Tape): The tape of :class:`Block` instances to manage.
+
+    Attributes:
+        tape (Tape): The tape of :class:`Block` instances to manage.
+        _schedule (checkpoint_schedules.schedule): A schedule provided by the `checkpoint_schedules` package.
+        forward_schedule (list): A list of `checkpoint_schedules` actions used to manage the execution of the
+            forward model.
+        reverse_schedule (list): A list of `checkpoint_schedules` actions used to manage the execution of the
+            reverse model.
+        timesteps (int): The total number of timesteps.
+        adjoint_evaluated (bool): A boolean indicating whether the adjoint model has been evaluated.
+        mode (Mode): The mode of the checkpoint manager. The possible modes are `RECORD`, `FINISHED_RECORDING`,
+            `EVALUATED`, `EXHAUSTED`, `RECOMPUTE`, and `EVALUATE_ADJOINT`. Additional information about the
+            modes can be found in :class:`Mode`.
+        _current_action (checkpoint_schedules.CheckpointAction): The current `checkpoint_schedules` action.
+
+    """
+    def __init__(self, schedule, tape):
+        if (
+            not isinstance(schedule, Revolve)
+            and not isinstance(schedule, MultistageCheckpointSchedule)
+        ):
+            raise CheckpointError(
+                "Only Revolve and MultistageCheckpointSchedule schedules are supported."
+ ) + + if ( + schedule.uses_storage_type(StorageType.DISK) + and not tape._package_data + ): + raise CheckpointError( + "The schedule employs disk checkpointing but it is not configured." + ) + self.tape = tape + self._schedule = schedule + self.forward_schedule = [] + self.reverse_schedule = [] + self.timesteps = schedule.max_n + # This variable is used to indicate whether the adjoint model has been evaluated at the checkpoint. + self.adjoint_evaluated = False + self.mode = Mode.RECORD + self._current_action = next(self._schedule) + self.forward_schedule.append(self._current_action) + # Tell the tape to only checkpoint input data until told otherwise. + self.tape.latest_checkpoint = 0 + self.end_timestep(-1) + + def end_timestep(self, timestep): + """Mark the end of one timestep when taping the forward model. + + Args: + timestep (int): The current timestep. + """ + if self.mode == Mode.EVALUATED: + raise CheckpointError("Not enough timesteps in schedule.") + elif self.mode != Mode.RECORD: + raise CheckpointError(f"Cannot end timestep in {self.mode}") + while not self.process_taping(self._current_action, timestep + 1): + self._current_action = next(self._schedule) + self.forward_schedule.append(self._current_action) + + def end_taping(self): + """Process the end of the forward execution.""" + current_timestep = self.tape.latest_timestep + while self.mode != Mode.EVALUATED: + self.end_timestep(current_timestep) + current_timestep += 1 + + @singledispatchmethod + def process_taping(self, cp_action, timestep): + """Implement checkpointing schedule actions while taping. + + A single-dispatch generic function. + + Note: + To have more information about the `checkpoint_schedules`, please refer to the + `documentation `_. + Detailed descriptions of the actions used in the process taping can be found at the following links: + `Forward `_ and `End_Forward `_. + + Args: + cp_action (checkpoint_schedules.CheckpointAction): A checkpoint action obtained from the + `checkpoint_schedules`. + timestep (int): The current timestep. + + Returns: + bool: Returns `True` if the timestep is in the `checkpoint_schedules` action. + For example, if the `checkpoint_schedules` action is `Forward(0, 4, True, False, StorageType.DISK)`, + then timestep `0, 1, 2, 3` is considered within the action; timestep `4` is not considered within the + action and `False` is returned. + + Raises: + CheckpointError: If the checkpoint action is not supported. + """ + + raise CheckpointError(f"Unable to process {cp_action} while taping.") + + @process_taping.register(Forward) + def _(self, cp_action, timestep): + if timestep < (cp_action.n0): + raise CheckpointError( + "Timestep is before start of Forward action." + ) + + self.tape._eagerly_checkpoint_outputs = cp_action.write_adj_deps + + if timestep > cp_action.n0: + if cp_action.write_ics and timestep == (cp_action.n0 + 1): + # Stores the checkpoint data in RAM + # This data will be used to restart the forward model + # from the step `n0` in the reverse computations. + self.tape.timesteps[cp_action.n0].checkpoint() + + if not cp_action.write_adj_deps: + # Remove unnecessary variables from previous steps. + for var in self.tape.timesteps[timestep - 1].checkpointable_state: + var._checkpoint = None + for block in self.tape.timesteps[timestep - 1]: + # Remove unnecessary variables from previous steps. 
+                    for output in block.get_outputs():
+                        output._checkpoint = None
+
+        if timestep in cp_action:
+            self.tape.get_blocks().append_step()
+            if cp_action.write_ics:
+                self.tape.latest_checkpoint = cp_action.n0
+            return True
+        else:
+            return False
+
+    @process_taping.register(EndForward)
+    def _(self, cp_action, timestep):
+        if timestep != self.timesteps:
+            raise CheckpointError(
+                "The correct number of forward steps has not been taken."
+            )
+        self.mode = Mode.EVALUATED
+        return True
+
+    def recompute(self, functional=None):
+        """Recompute the forward model.
+
+        Args:
+            functional (BlockVariable): The functional to be evaluated.
+        """
+        self.mode = Mode.RECOMPUTE
+
+        with self.tape.progress_bar("Evaluating Functional",
+                                    max=self.timesteps) as bar:
+            # Restore the initial condition to advance the forward model
+            # from step 0.
+            current_step = self.tape.timesteps[self.forward_schedule[0].n0]
+            current_step.restore_from_checkpoint()
+            for cp_action in self.forward_schedule:
+                self._current_action = cp_action
+                self.process_operation(cp_action, bar, functional=functional)
+
+    def evaluate_adj(self, last_block, markings):
+        """Evaluate the adjoint model.
+
+        Args:
+            last_block (int): The last block to be evaluated.
+            markings (bool): If `True`, then each `BlockVariable` of the current block
+                will have its `marked_in_path` attribute set, indicating whether its
+                adjoint components are relevant for computing the final target adjoint
+                values.
+        """
+        # Work out other cases when they arise.
+        if last_block != 0:
+            raise NotImplementedError(
+                "Only the first block can be evaluated at present."
+            )
+
+        if self.mode == Mode.RECORD:
+            # The declared timesteps were not exhausted while taping.
+            self.end_taping()
+
+        if self.mode not in (Mode.EVALUATED, Mode.FINISHED_RECORDING):
+            raise CheckpointError("Evaluate the functional before calling the gradient.")
+
+        with self.tape.progress_bar("Evaluating Adjoint",
+                                    max=self.timesteps) as bar:
+            if self.adjoint_evaluated:
+                reverse_iterator = iter(self.reverse_schedule)
+            while not isinstance(self._current_action, EndReverse):
+                if not self.adjoint_evaluated:
+                    self._current_action = next(self._schedule)
+                    self.reverse_schedule.append(self._current_action)
+                else:
+                    self._current_action = next(reverse_iterator)
+                self.process_operation(self._current_action, bar, markings=markings)
+                # Only set the mode after the first backward in order to handle
+                # that step correctly.
+                self.mode = Mode.EVALUATE_ADJOINT

+        # Inform that the adjoint model has been evaluated.
+        self.adjoint_evaluated = True
+
+    @singledispatchmethod
+    def process_operation(self, cp_action, bar, **kwargs):
+        """Process the forward and adjoint executions.
+
+        This single-dispatch generic function is used in the recomputation of the
+        blocks and in the adjoint evaluation with checkpointing.
+
+        Note:
+            The documentation of the `checkpoint_schedules` actions is available
+            in the `checkpoint_schedules` package documentation.
+
+        Args:
+            cp_action (checkpoint_schedules.CheckpointAction): A checkpoint action obtained from
+                `checkpoint_schedules`.
+            bar (progressbar.ProgressBar): A progress bar to display the progress of the reverse executions.
+            kwargs: Additional keyword arguments.
+
+        Raises:
+            CheckpointError: If the checkpoint action is not supported.
+        """
+        raise CheckpointError(f"Unable to process {cp_action}.")
+
+    @process_operation.register(Forward)
+    def _(self, cp_action, bar, functional=None, **kwargs):
+        for step in cp_action:
+            if self.mode == Mode.RECOMPUTE:
+                bar.next()
+            # Get the blocks of the current step.
+            current_step = self.tape.timesteps[step]
+            for block in current_step:
+                block.recompute()
+
+            if cp_action.write_ics:
+                if step == cp_action.n0:
+                    for var in current_step.checkpointable_state:
+                        if var.checkpoint:
+                            current_step._checkpoint.update(
+                                {var: var.checkpoint}
+                            )
+            if not cp_action.write_adj_deps:
+                next_step = self.tape.timesteps[step + 1]
+                # The checkpointable state set of the current step.
+                to_keep = next_step.checkpointable_state
+                if functional:
+                    # `to_keep` holds the information required to restart the
+                    # forward model from a step `n`.
+                    to_keep = to_keep.union([functional.block_variable])
+                for block in current_step:
+                    # Remove unnecessary block outputs from previous steps.
+                    for bv in block.get_outputs():
+                        if bv not in to_keep:
+                            bv._checkpoint = None
+                # Remove unnecessary variables from previous steps.
+                for var in (current_step.checkpointable_state - to_keep):
+                    var._checkpoint = None
+
+    @process_operation.register(Reverse)
+    def _(self, cp_action, bar, markings, functional=None, **kwargs):
+        for step in cp_action:
+            bar.next()
+            # Get the blocks of the current step.
+            current_step = self.tape.timesteps[step]
+            for block in reversed(current_step):
+                block.evaluate_adj(markings=markings)
+            # Output variables are used for the last time when running
+            # backwards.
+            for block in current_step:
+                for var in block.get_outputs():
+                    var.checkpoint = None
+                    var.reset_variables(("tlm",))
+                    if not var.is_control:
+                        var.reset_variables(("adjoint", "hessian"))
+                if cp_action.clear_adj_deps:
+                    to_keep = current_step.checkpointable_state
+                    if functional:
+                        to_keep = to_keep.union([functional.block_variable])
+                    for output in block.get_outputs():
+                        if output not in to_keep:
+                            output._checkpoint = None
+
+    @process_operation.register(Copy)
+    def _(self, cp_action, bar, **kwargs):
+        current_step = self.tape.timesteps[cp_action.n]
+        current_step.restore_from_checkpoint()
+
+    @process_operation.register(Move)
+    def _(self, cp_action, bar, **kwargs):
+        current_step = self.tape.timesteps[cp_action.n]
+        current_step.restore_from_checkpoint()
+        current_step.delete_checkpoint()
+
+    @process_operation.register(EndForward)
+    def _(self, cp_action, bar, **kwargs):
+        self.mode = Mode.EVALUATED
+
+    @process_operation.register(EndReverse)
+    def _(self, cp_action, bar, **kwargs):
+        if self._schedule.is_exhausted:
+            self.mode = Mode.EXHAUSTED
+        else:
+            self.mode = Mode.EVALUATED
diff --git a/pyadjoint/drivers.py b/pyadjoint/drivers.py
index fb631d17..9bf03b0c 100644
--- a/pyadjoint/drivers.py
+++ b/pyadjoint/drivers.py
@@ -26,7 +26,8 @@ def compute_gradient(J, m, options=None, tape=None, adj_value=1.0):

     with stop_annotating():
         with tape.marked_nodes(m):
-            tape.evaluate_adj(markings=True)
+            with marked_controls(m):
+                tape.evaluate_adj(markings=True)

     grads = [i.get_derivative(options=options) for i in m]
     return m.delist(grads)
@@ -91,3 +92,26 @@ def solve_adjoint(J, tape=None, adj_value=1.0):

     with stop_annotating():
         tape.evaluate_adj(markings=False)
+
+
+class marked_controls:
+    """A context manager for marking controls.
+
+    Note:
+        This context manager marks each :class:`BlockVariable` as a control. On
+        exiting the context, the :class:`BlockVariable` objects that were marked
+        as controls are automatically unmarked.
+
+    Args:
+        controls (list): A list of :class:`Control` to mark within the context manager.
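+
+    Example:
+        A minimal sketch of typical use, assuming ``tape`` is the working tape
+        and ``controls`` is a list of :class:`Control` objects already in
+        scope::
+
+            with marked_controls(controls):
+                tape.evaluate_adj(markings=True)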
+ """ + def __init__(self, controls): + self.controls = controls + + def __enter__(self): + for control in self.controls: + control.mark_as_control() + + def __exit__(self, *args): + for control in self.controls: + control.unmark_as_control() diff --git a/pyadjoint/reduced_functional.py b/pyadjoint/reduced_functional.py index 993615df..8b2af45f 100644 --- a/pyadjoint/reduced_functional.py +++ b/pyadjoint/reduced_functional.py @@ -206,10 +206,13 @@ def __call__(self, values): blocks = self.tape.get_blocks() with self.marked_controls(): with stop_annotating(): - for i in self.tape._bar("Evaluating functional").iter( - range(len(blocks)) - ): - blocks[i].recompute() + if self.tape._checkpoint_manager: + self.tape._checkpoint_manager.recompute(self.functional) + else: + for i in self.tape._bar("Evaluating functional").iter( + range(len(blocks)) + ): + blocks[i].recompute() # ReducedFunctional can result in a scalar or an assembled 1-form func_value = self.functional.block_variable.saved_output diff --git a/pyadjoint/tape.py b/pyadjoint/tape.py index 9814beb8..996527c1 100644 --- a/pyadjoint/tape.py +++ b/pyadjoint/tape.py @@ -6,7 +6,9 @@ from functools import wraps from itertools import chain from abc import ABC, abstractmethod - +from typing import Optional +from collections.abc import Iterable +from .checkpointing import CheckpointManager, CheckpointError _working_tape = None _annotation_enabled = False @@ -160,11 +162,13 @@ class Tape(object): """ __slots__ = ["_blocks", "_tf_tensors", "_tf_added_blocks", "_nodes", - "_tf_registered_blocks", "_bar", "_package_data"] + "_tf_registered_blocks", "_bar", "_package_data", + "_checkpoint_manager", "latest_checkpoint", + "_eagerly_checkpoint_outputs"] def __init__(self, blocks=None, package_data=None): # Initialize the list of blocks on the tape. - self._blocks = [] if blocks is None else blocks + self._blocks = TimeStepSequence(blocks=blocks) # Dictionary of TensorFlow tensors. Key is id(block). self._tf_tensors = {} # Keep a list of blocks that has been added to the TensorFlow graph @@ -174,12 +178,54 @@ def __init__(self, blocks=None, package_data=None): # Hook location for packages which need to store additional data on the # tape. Packages should store the data under a "packagename" key. self._package_data = package_data or {} + # Default to checkpointing all block variables. + self.latest_checkpoint = float("inf") + self._checkpoint_manager = None + # Whether to store the adjoint dependencies. + self._eagerly_checkpoint_outputs = False def clear_tape(self): + """Clear the tape.""" self.reset_variables() - self._blocks = [] + self._blocks = TimeStepSequence() for data in self._package_data.values(): data.clear() + self._checkpoint_manager = None + + @property + def latest_timestep(self): + """The current time step to which blocks will be added.""" + return max(len(self._blocks.steps) - 1, 0) + + def end_timestep(self): + """Mark the end of a timestep when taping the forward model.""" + if self._checkpoint_manager: + self._checkpoint_manager.end_timestep(self.latest_timestep) + else: + self._blocks.append_step() + + def timestepper(self, iterable): + """Return an iterator that advances the tape timestep. + + Note: + This method facilitates taping timestepping simulations so that recompute + checkpointing can be used on the tape. For example, a simulation with + 10 timesteps might use a timestepping loop of this form:: + + tape = get_working_tape() + + for timestep in tape.timestepper(range(10)): + ... 
+
+        This has the effect of calling `tape.end_timestep()` after each iteration.
+
+        Args:
+            iterable (iterable): The iterable defining the sequence of timesteps.
+
+        Returns:
+            TapeTimeStepper: An iterator that advances the tape timestep.
+        """
+        return TapeTimeStepper(self, iterable)

     def reset_blocks(self):
         """Calls the Block.reset method of all blocks on the tape.
@@ -200,6 +246,39 @@ def add_block(self, block):
         # len() is computed in constant time, so this should be fine.
         return len(self._blocks) - 1

+    def add_to_checkpointable_state(self, block_var, last_used):
+        """Add a block variable into the checkpointable state set.
+
+        Note:
+            `checkpointable_state` is a set of block variables which are needed
+            to restart from the start of a timestep.
+
+        Args:
+            block_var (BlockVariable): The block variable to add.
+            last_used (int): The last timestep in which the block variable was used.
+        """
+        if not self.timesteps:
+            self._blocks.append_step()
+        for step in self.timesteps[last_used + 1:]:
+            step.checkpointable_state.add(block_var)
+
+    def enable_checkpointing(self, schedule):
+        """Enable checkpointing on the adjoint evaluation.
+
+        This sets up a checkpoint manager that executes the forward and adjoint
+        computations according to the schedule provided by the
+        `checkpoint_schedules` package.
+
+        Args:
+            schedule (checkpoint_schedules.schedule): A schedule provided by the
+                `checkpoint_schedules` package.
+        """
+        if self._blocks:
+            raise CheckpointError(
+                "Checkpointing must be enabled before any blocks are added to the tape."
+            )
+        self._checkpoint_manager = CheckpointManager(schedule, self)
+
     def get_blocks(self, tag=None):
         """Returns a list of the blocks on the tape.
@@ -226,10 +305,21 @@ def get_tags(self):
         return tags

     def evaluate_adj(self, last_block=0, markings=False):
-        for i in self._bar("Evaluating adjoint").iter(
-            range(len(self._blocks) - 1, last_block - 1, -1)
-        ):
-            self._blocks[i].evaluate_adj(markings=markings)
+        """Evaluate the adjoint of the tape.
+
+        Args:
+            last_block (int, optional): The index of the last block to evaluate.
+            markings (bool, optional): If True, then each `BlockVariable` of the current
+                block will have its `marked_in_path` attribute set, indicating whether
+                its adjoint components are relevant for computing the final target
+                adjoint values.
+        """
+        if self._checkpoint_manager:
+            self._checkpoint_manager.evaluate_adj(last_block, markings)
+        else:
+            for i in self._bar("Evaluating adjoint").iter(
+                range(len(self._blocks) - 1, last_block - 1, -1)
+            ):
+                self._blocks[i].evaluate_adj(markings=markings)

     def evaluate_tlm(self):
         for i in self._bar("Evaluating TLM").iter(
@@ -264,7 +354,7 @@ def copy(self):
         """
         # TODO: Offer deepcopying. But is it feasible memory wise to copy all checkpoints?
         return Tape(
-            blocks=self._blocks[:],
+            blocks=self._blocks,
             package_data={k: v.copy() for k, v in self._package_data.items()}
         )
@@ -319,40 +409,60 @@ def optimize(self, controls=None, functionals=None):

     def optimize_for_controls(self, controls):
         # TODO: Consider if we want Enlist wherever it is possible. Like in this case.
         # TODO: Consider warning/message on empty tape.
-        blocks = self.get_blocks()
         nodes = set([control.block_variable for control in controls])
-        valid_blocks = []
-
-        for block in blocks:
-            depends_on_control = False
-            for dep in block.get_dependencies():
-                if dep in nodes:
-                    depends_on_control = True
-
-            if depends_on_control:
-                for output in block.get_outputs():
-                    if output in nodes:
-                        raise RuntimeError("Control depends on another control.")
-                    nodes.add(output)
-                valid_blocks.append(block)
-        self._blocks = valid_blocks
+        discarded_variables = set()
+        optimized_timesteps = TimeStepSequence()
+
+        for step in self._blocks.steps:
+            optimized_timesteps.append_step()
+
+            for block in step:
+                depends_on_control = False
+                for dep in block.get_dependencies():
+                    if dep in nodes:
+                        depends_on_control = True
+                        break
+
+                if depends_on_control:
+                    for output in block.get_outputs():
+                        if output in nodes:
+                            raise RuntimeError("Control depends on another control.")
+                        nodes.add(output)
+                    optimized_timesteps.append(block)
+                else:
+                    discarded_variables.update(block.get_outputs())
+            optimized_timesteps.steps[-1].checkpointable_state = \
+                step.checkpointable_state - discarded_variables
+
+        self._blocks = optimized_timesteps

     def optimize_for_functionals(self, functionals):
-        blocks = self.get_blocks()
-        nodes = set([functional.block_variable for functional in functionals])
-        valid_blocks = []
+        retained_nodes = set([functional.block_variable
+                              for functional in functionals])
+        optimized_timesteps = []
+
+        for step in reversed(self._blocks.steps):
+            current_blocks = []
+            for block in reversed(step):
+                produces_functional = False
+                for dep in block.get_outputs():
+                    if dep in retained_nodes:
+                        produces_functional = True
+
+                if produces_functional:
+                    for dep in block.get_dependencies():
+                        retained_nodes.add(dep)
+                    current_blocks.append(block)
+            optimized_timesteps.append(TimeStep(reversed(current_blocks)))

-        for block in reversed(blocks):
-            produces_functional = False
-            for dep in block.get_outputs():
-                if dep in nodes:
-                    produces_functional = True
+        optimized_timesteps.reverse()

-            if produces_functional:
-                for dep in block.get_dependencies():
-                    nodes.add(dep)
-                valid_blocks.append(block)
-        self._blocks = list(reversed(valid_blocks))
+        for step, new_step in zip(self._blocks.steps, optimized_timesteps):
+            new_step.checkpointable_state = \
+                step.checkpointable_state & retained_nodes
+
+        self._blocks = TimeStepSequence(steps=optimized_timesteps)

     @contextmanager
     def marked_nodes(self, controls):
@@ -363,6 +473,11 @@ def marked_nodes(self, controls):
         for node in nodes:
             node.marked_in_path = False

+    @property
+    def timesteps(self):
+        """Return the list of time steps on this tape."""
+        return self._blocks.steps
+
     def _valid_tf_scope_name(self, name):
         """Return a valid TensorFlow scope name"""
         valid_name = ""
@@ -596,6 +711,105 @@ def iter(self, iterator):
         return iterator


+class TapeTimeStepper:
+    """Iterator wrapper which advances the timestep after each iteration."""
+    def __init__(self, tape, iterable):
+        self.tape = tape
+        self.iterator = tape.progress_bar("Taping forward").iter(iterable)
+        self._first = True
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self._first:
+            self._first = False
+        else:
+            self.tape.end_timestep()
+        return next(self.iterator)
+
+
+class TimeStep(list):
+    """A list of blocks in a single time step, plus associated metadata."""
+    def __init__(self, blocks=()):
+        super().__init__(blocks)
+        # The set of block variables which are needed to restart from the start
+        # of this timestep.
+        self.checkpointable_state = set()
+        # A dictionary mapping the block variables in the checkpointable state
+        # to their checkpoint values.
+        self._checkpoint = {}
+
+    def copy(self, blocks=None):
+        out = TimeStep(blocks or self)
+        out.checkpointable_state = self.checkpointable_state
+        return out
+
+    def checkpoint(self):
+        """Store a copy of the checkpoints in the checkpointable state."""
+        with stop_annotating():
+            self._checkpoint = {
+                var: var.saved_output._ad_create_checkpoint()
+                for var in self.checkpointable_state
+            }
+
+    def restore_from_checkpoint(self):
+        """Restore the block variable checkpoints from the timestep checkpoint."""
+        for var in self._checkpoint:
+            var.checkpoint = self._checkpoint[var]
+
+    def delete_checkpoint(self):
+        """Delete the stored checkpoint references."""
+        self._checkpoint = {}
+
+
+class TimeStepSequence(list):
+    """A list of Blocks separated into timesteps to facilitate checkpointing.
+
+    This behaves like a list of blocks. To access a list of the timesteps, use
+    the :attr:`steps` property.
+    """
+
+    def __init__(self, blocks=None, steps: Optional[Iterable[TimeStep]] = None):
+        # Keep both per-timestep and unified block lists.
+        if steps and blocks:
+            raise ValueError("Set blocks or steps but not both.")
+        elif isinstance(blocks, TimeStepSequence):
+            self._steps = [step.copy() for step in blocks._steps]
+        elif blocks:
+            self._steps = [TimeStep(blocks)]
+        else:
+            self._steps = list(step.copy() for step in steps) if steps else []
+        super().__init__(chain.from_iterable(self._steps))
+
+    @property
+    def steps(self):
+        return self._steps
+
+    def append(self, other):
+        """Add a new block to the sequence and to the current TimeStep."""
+        if not self.steps:
+            self.append_step()
+        self._steps[-1].append(other)
+        super().append(other)
+
+    def append_step(self, step=None):
+        """Add a new TimeStep."""
+        self._steps.append(step or TimeStep())
+
+    def __setitem__(self, key, value):
+        raise ValueError(
+            "Unable to set arbitrary blocks. Try appending instead."
+        )
+
+    def __delitem__(self, key):
+        raise ValueError(
+            "Unable to delete blocks from sequence."
+        )
+
+
 class TapePackageData(ABC):
     """Abstract base class for additional data that packages store on the
     tape.
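Taken together, the tape changes above give the following user-facing workflow. The sketch below is illustrative rather than part of the patch: `total_steps` and `forward_step` are hypothetical placeholders for an application's own timestepping code, while `Revolve`, `enable_checkpointing` and `timestepper` are the APIs introduced here::

    from checkpoint_schedules import Revolve
    from pyadjoint import get_working_tape

    tape = get_working_tape()
    total_steps = 10  # hypothetical number of timesteps
    # Checkpointing must be enabled before any blocks are taped.
    tape.enable_checkpointing(Revolve(total_steps, total_steps // 3))
    for step in tape.timestepper(range(total_steps)):
        forward_step(step)  # hypothetical taped forward operations for one step
    # tape.end_timestep() is called automatically after each iteration;
    # the subsequent adjoint evaluation is driven by the CheckpointManager.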
diff --git a/setup.py b/setup.py
index 73e521cd..c94b7e63 100644
--- a/setup.py
+++ b/setup.py
@@ -24,6 +24,6 @@
       package_dir={'pyadjoint': 'pyadjoint',
                    'firedrake_adjoint': 'firedrake_adjoint',
                    'numpy_adjoint': 'numpy_adjoint'},
-      install_requires=['scipy>=1.0'],
+      install_requires=['scipy>=1.0', 'checkpoint-schedules'],
       extras_require=extras
       )
diff --git a/tests/firedrake_adjoint/test_burgers_newton.py b/tests/firedrake_adjoint/test_burgers_newton.py
index 9197bf28..e063296e 100644
--- a/tests/firedrake_adjoint/test_burgers_newton.py
+++ b/tests/firedrake_adjoint/test_burgers_newton.py
@@ -7,27 +7,29 @@
 from firedrake import *
 from firedrake.adjoint import *
-
-
+from checkpoint_schedules import Revolve, MultistageCheckpointSchedule
+import numpy as np
 set_log_level(CRITICAL)
-
+continue_annotation()
 n = 30
 mesh = UnitIntervalMesh(n)
 V = FunctionSpace(mesh, "CG", 2)
+end = 0.3
+timestep = Constant(1.0/n)
+steps = int(end/float(timestep)) + 1
+

 def Dt(u, u_, timestep):
     return (u - u_)/timestep

-def J(ic, solve_type):
+
+def J(ic, solve_type, checkpointing):
     u_ = Function(V)
     u = Function(V)
     v = TestFunction(V)
-
+    u_.assign(ic)
     nu = Constant(0.0001)
-
-    timestep = Constant(1.0/n)
-
-    F = (Dt(u, ic, timestep)*v
+    F = (Dt(u, u_, timestep)*v
          + u*u.dx(0)*v + nu*u.dx(0)*v.dx(0))*dx
     bc = DirichletBC(V, 0.0, "on_boundary")

@@ -35,38 +37,94 @@ def J(ic, solve_type):
     if solve_type == "NLVS":
         problem = NonlinearVariationalProblem(F, u, bcs=bc)
         solver = NonlinearVariationalSolver(problem)
-        solver.solve()
-    else:
-        solve(F == 0, u, bc)
-    u_.assign(u)
-    t += float(timestep)
-    F = (Dt(u, u_, timestep)*v
-         + u*u.dx(0)*v + nu*u.dx(0)*v.dx(0))*dx
-
-    end = 0.2
-    while (t <= end):
+    tape = get_working_tape()
+    t += float(timestep)
+    for t in tape.timestepper(np.arange(t, end + t, float(timestep))):
         if solve_type == "NLVS":
             solver.solve()
         else:
             solve(F == 0, u, bc)
         u_.assign(u)
-        t += float(timestep)
+    return assemble(u_*u_*dx + ic*ic*dx), u_
+
+
+@pytest.mark.parametrize("solve_type, checkpointing",
+                         [("solve", "Revolve"),
+                          ("NLVS", "Revolve"),
+                          ("solve", "Multistage"),
+                          ("NLVS", "Multistage"),
+                          ("solve", None),
+                          ("NLVS", None),
+                          ])
+def test_burgers_newton(solve_type, checkpointing):
+    """Adjoint-based gradient tests with and without checkpointing.
+    """
+    tape = get_working_tape()
+    tape.progress_bar = ProgressBar
+    if checkpointing == "Revolve":
+        tape.enable_checkpointing(Revolve(steps, steps//3))
+    if checkpointing == "Multistage":
+        tape.enable_checkpointing(MultistageCheckpointSchedule(steps, steps//3, 0))
+    x, = SpatialCoordinate(mesh)
+    ic = project(sin(2.*pi*x), V)
+    val, _ = J(ic, solve_type, checkpointing)
+    if checkpointing:
+        assert len(tape.timesteps) == steps

-    return assemble(u_*u_*dx + ic*ic*dx)
+    Jhat = ReducedFunctional(val, Control(ic))
+    dJ = Jhat.derivative()
+    # Recomputing the functional with a modified control variable
+    # before the recompute test.
+    Jhat(project(sin(pi*x), V))

-@pytest.mark.parametrize("solve_type",
-                         ["solve", "NLVS"])
-def test_burgers_newton(solve_type):
-    x, = SpatialCoordinate(mesh)
-    ic = project(sin(2*pi*x), V)
+    # Recompute test
+    assert np.allclose(Jhat(ic), val)

-    val = J(ic, solve_type)
-
-    Jhat = ReducedFunctional(val, Control(ic))
+    dJbar = Jhat.derivative()
+    # Test that the recomputed adjoint-based gradient matches.
+    assert np.allclose(dJ.dat.data_ro[:], dJbar.dat.data_ro[:])

+    # Taylor test
     h = Function(V)
     h.assign(1, annotate=False)
     assert taylor_test(Jhat, ic, h) > 1.9
+
+
+@pytest.mark.parametrize("solve_type, checkpointing",
+                         [("solve", "Revolve"),
+                          ("NLVS", "Revolve"),
+                          ("solve", "Multistage"),
+                          ("NLVS", "Multistage")
+                          ])
+def test_checkpointing_validity(solve_type, checkpointing):
+    """Compare forward and backward results with and without checkpointing.
+    """
+    # Without checkpointing
+    tape = get_working_tape()
+    tape.progress_bar = ProgressBar
+    x, = SpatialCoordinate(mesh)
+    ic = project(sin(2.*pi*x), V)
+
+    val0, u0 = J(ic, solve_type, False)
+    Jhat = ReducedFunctional(val0, Control(ic))
+    dJ0 = Jhat.derivative()
+    tape.clear_tape()
+
+    # With checkpointing
+    tape.progress_bar = ProgressBar
+    if checkpointing == "Revolve":
+        tape.enable_checkpointing(Revolve(steps, steps//3))
+    if checkpointing == "Multistage":
+        tape.enable_checkpointing(MultistageCheckpointSchedule(steps, steps//3, 0))
+    x, = SpatialCoordinate(mesh)
+    ic = project(sin(2.*pi*x), V)
+    val1, u1 = J(ic, solve_type, True)
+    Jhat = ReducedFunctional(val1, Control(ic))
+    dJ1 = Jhat.derivative()
+    assert len(tape.timesteps) == steps
+    assert np.allclose(val0, val1)
+    assert np.allclose(u0.dat.data_ro[:], u1.dat.data_ro[:])
+    assert np.allclose(dJ0.dat.data_ro[:], dJ1.dat.data_ro[:])
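For reference, the stream of actions that the `CheckpointManager` consumes can be inspected directly, mirroring how `evaluate_adj` above advances the schedule with `next` until an `EndReverse` action is reached. A minimal sketch, with arbitrary step and snapshot counts::

    from checkpoint_schedules import Revolve, EndReverse

    # Enumerate the actions Revolve emits for 6 timesteps with 2 snapshots.
    schedule = Revolve(6, 2)
    actions = []
    while True:
        action = next(schedule)  # e.g. Forward, Reverse, Copy, Move, ...
        actions.append(action)
        if isinstance(action, EndReverse):
            break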