diff --git a/pyadjoint/tape.py b/pyadjoint/tape.py index 8e97f8fe..9814beb8 100644 --- a/pyadjoint/tape.py +++ b/pyadjoint/tape.py @@ -9,7 +9,7 @@ _working_tape = None -_stop_annotating = 0 +_annotation_enabled = False def get_working_tape(): @@ -17,14 +17,14 @@ def get_working_tape(): def pause_annotation(): - global _stop_annotating - _stop_annotating += 1 + global _annotation_enabled + _annotation_enabled = False def continue_annotation(): - global _stop_annotating - _stop_annotating -= 1 - return _stop_annotating <= 0 + global _annotation_enabled + _annotation_enabled = True + return _annotation_enabled class set_working_tape(object): @@ -81,13 +81,17 @@ class stop_annotating(object): modified variables at the end of the context manager. """ def __init__(self, modifies=None): + global _annotation_enabled self.modifies = modifies + self._orig_annotation_enabled = _annotation_enabled def __enter__(self): - pause_annotation() + global _annotation_enabled + _annotation_enabled = False def __exit__(self, *args): - continue_annotation() + global _annotation_enabled + _annotation_enabled = self._orig_annotation_enabled if self.modifies is not None: try: self.modifies.create_block_variable() @@ -125,7 +129,7 @@ def annotate_tape(kwargs=None): # TODO: Consider if there is any scenario where one would want the keyword to have # precedence over the global flag. - if _stop_annotating > 0: + if not _annotation_enabled: return False return annotate