From 6ef3daee48d4c28e06065e05cd0d7afa1979d047 Mon Sep 17 00:00:00 2001 From: Iglesia Dolci Date: Mon, 18 Mar 2024 08:23:02 +0000 Subject: [PATCH 1/2] wip --- pyadjoint/checkpointing.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/pyadjoint/checkpointing.py b/pyadjoint/checkpointing.py index d20a0b2f..bed12361 100644 --- a/pyadjoint/checkpointing.py +++ b/pyadjoint/checkpointing.py @@ -257,7 +257,10 @@ def process_operation(self, cp_action, bar, **kwargs): def _(self, cp_action, bar, functional=None, **kwargs): for step in cp_action: if self.mode == Mode.RECOMPUTE: - bar.next() + try: + bar.next() + except AttributeError: + pass # Get the blocks of the current step. current_step = self.tape.timesteps[step] for block in current_step: @@ -290,7 +293,10 @@ def _(self, cp_action, bar, functional=None, **kwargs): @process_operation.register(Reverse) def _(self, cp_action, bar, markings, functional=None, **kwargs): for step in cp_action: - bar.next() + try: + bar.next() + except AttributeError: + pass # Get the blocks of the current step. current_step = self.tape.timesteps[step] for block in reversed(current_step): From 6bc1d7f3d9a6af2b417fb887e61953cf87f75363 Mon Sep 17 00:00:00 2001 From: Iglesia Dolci Date: Mon, 18 Mar 2024 12:07:09 +0000 Subject: [PATCH 2/2] Make the progress bar print optional. --- pyadjoint/checkpointing.py | 11 +++-------- pyadjoint/tape.py | 2 +- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/pyadjoint/checkpointing.py b/pyadjoint/checkpointing.py index bed12361..01809101 100644 --- a/pyadjoint/checkpointing.py +++ b/pyadjoint/checkpointing.py @@ -214,8 +214,7 @@ def evaluate_adj(self, last_block, markings): if self.mode not in (Mode.EVALUATED, Mode.FINISHED_RECORDING): raise CheckpointError("Evaluate Functional before calling gradient.") - with self.tape.progress_bar("Evaluating Adjoint", - max=self.timesteps) as bar: + with self.tape.progress_bar("Evaluating Adjoint", max=self.timesteps) as bar: if self.adjoint_evaluated: reverse_iterator = iter(self.reverse_schedule) while not isinstance(self._current_action, EndReverse): @@ -257,10 +256,8 @@ def process_operation(self, cp_action, bar, **kwargs): def _(self, cp_action, bar, functional=None, **kwargs): for step in cp_action: if self.mode == Mode.RECOMPUTE: - try: + if bar: bar.next() - except AttributeError: - pass # Get the blocks of the current step. current_step = self.tape.timesteps[step] for block in current_step: @@ -293,10 +290,8 @@ def _(self, cp_action, bar, functional=None, **kwargs): @process_operation.register(Reverse) def _(self, cp_action, bar, markings, functional=None, **kwargs): for step in cp_action: - try: + if bar: bar.next() - except AttributeError: - pass # Get the blocks of the current step. current_step = self.tape.timesteps[step] for block in reversed(current_step): diff --git a/pyadjoint/tape.py b/pyadjoint/tape.py index 0774d260..98dfe994 100644 --- a/pyadjoint/tape.py +++ b/pyadjoint/tape.py @@ -703,7 +703,7 @@ def __init__(self, *args, **kwargs): def __enter__(self): pass - def __exit__(self): + def __exit__(self, *args, **kwargs): pass def iter(self, iterator):