Skip to content

Commit

Permalink
Merge pull request #177 from dolfin-adjoint/JHopeCollins/set_working_…
Browse files Browse the repository at this point in the history
…tape_decorator

Allow using `set_working_tape` as a function decorator
  • Loading branch information
JHopeCollins authored Nov 22, 2024
2 parents a8ee848 + fb3b5bd commit 5f46e16
Showing 1 changed file with 38 additions and 19 deletions.
57 changes: 38 additions & 19 deletions pyadjoint/tape.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
import os
import re
import threading
from contextlib import contextmanager
from functools import wraps
from contextlib import contextmanager, ContextDecorator
from itertools import chain
from typing import Optional, Iterable
from abc import ABC, abstractmethod
Expand All @@ -28,9 +27,13 @@ def continue_annotation():
return _annotation_enabled


class set_working_tape(object):
"""A context manager whithin which a new tape is set as the working tape.
This context manager can also be used in an imperative manner.
class set_working_tape(ContextDecorator):
"""Set a new tape as the working tape.
This class can be used in three ways:
1) as a free function to replace the working tape,
2) as a context manager within which a new tape is set as the working tape,
3) as a function decorator so that the new tape is set only inside the function.
Example usage:
Expand All @@ -48,6 +51,23 @@ class set_working_tape(object):
with set_working_tape() as tape:
...
3) Set the local tape inside a decorated function.
The two functions below are equivalent:
.. highlight:: python
.. code-block:: python
@set_working_tape()
def decorated_function(*args, **kwargs):
# do something here
return ReducedFunctional(functional, control)
def context_function(*args, **kwargs):
with set_working_tape():
# do something here
return ReducedFunctional(functional, control)
"""

def __init__(self, tape=None, **tape_kwargs):
Expand All @@ -68,8 +88,8 @@ def __exit__(self, *args):
_working_tape = self.old_tape


class stop_annotating(object):
"""A context manager within which annotation is stopped.
class stop_annotating(ContextDecorator):
"""A context manager and function decorator within which annotation is stopped.
Args:
modifies (OverloadedType or list[OverloadedType]): One or more
Expand All @@ -82,17 +102,23 @@ class stop_annotating(object):
modified variables at the end of the context manager. """

def __init__(self, modifies=None):
global _annotation_enabled
self.modifies = modifies
self._orig_annotation_enabled = _annotation_enabled
# the `no_annotations` context manager could be nested,
# so we need a stack to keep track of the original states.
self._orig_annotation_enabled = []

def __enter__(self):
global _annotation_enabled
if self.modifies and len(self._orig_annotation_enabled) != 0:
raise ValueError(
"Cannot use `modifies` argument if `stop_annotating` is nested,"
" e.g. if used as the `no_annotations` decorator.")
self._orig_annotation_enabled.append(_annotation_enabled)
_annotation_enabled = False

def __exit__(self, *args):
global _annotation_enabled
_annotation_enabled = self._orig_annotation_enabled
_annotation_enabled = self._orig_annotation_enabled.pop()
if self.modifies is not None:
try:
self.modifies.create_block_variable()
Expand All @@ -101,15 +127,8 @@ def __exit__(self, *args):
var.create_block_variable()


def no_annotations(function):
"""Decorator to turn off annotation for the decorated function."""

@wraps(function)
def wrapper(*args, **kwargs):
with stop_annotating():
return function(*args, **kwargs)

return wrapper
no_annotations = stop_annotating()
"""Decorator to turn off annotation for the decorated function."""


def annotate_tape(kwargs=None):
Expand Down

0 comments on commit 5f46e16

Please sign in to comment.