diff --git a/pyadjoint/checkpointing.py b/pyadjoint/checkpointing.py index d20a0b2f..bed12361 100644 --- a/pyadjoint/checkpointing.py +++ b/pyadjoint/checkpointing.py @@ -257,7 +257,10 @@ def process_operation(self, cp_action, bar, **kwargs): def _(self, cp_action, bar, functional=None, **kwargs): for step in cp_action: if self.mode == Mode.RECOMPUTE: - bar.next() + try: + bar.next() + except AttributeError: + pass # Get the blocks of the current step. current_step = self.tape.timesteps[step] for block in current_step: @@ -290,7 +293,10 @@ def _(self, cp_action, bar, functional=None, **kwargs): @process_operation.register(Reverse) def _(self, cp_action, bar, markings, functional=None, **kwargs): for step in cp_action: - bar.next() + try: + bar.next() + except AttributeError: + pass # Get the blocks of the current step. current_step = self.tape.timesteps[step] for block in reversed(current_step):