Skip to content

Commit

Permalink
Merge pull request #420 from tlm-adjoint/jrmaddison/tlm_map_fix
Browse files Browse the repository at this point in the history
Do not allow replacement variables to be used to define a tangent-linear
  • Loading branch information
jrmaddison authored Oct 25, 2023
2 parents 0b8fee1 + 9548a9d commit d5492fe
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 34 deletions.
2 changes: 1 addition & 1 deletion tlm_adjoint/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ def initialize(self, Js, blocks, transpose_deps, *,
if isinstance(eq, (ControlsMarker, FunctionalMarker)):
continue

eq_id = eq.id()
eq_id = eq.id
eq_tlm_root_id = getattr(eq, "_tlm_adjoint__tlm_root_id", eq_id) # noqa: E501
eq_tlm_key = getattr(eq, "_tlm_adjoint__tlm_key", ())

Expand Down
21 changes: 10 additions & 11 deletions tlm_adjoint/caches.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def clear_caches(*deps):
"""

if len(deps) == 0:
for cache in tuple(Cache._caches.valuerefs()):
cache = cache()
for cache_id in sorted(tuple(Cache._caches)):
cache = Cache._caches.get(cache_id, None)
if cache is not None:
cache.clear()
else:
Expand Down Expand Up @@ -94,7 +94,7 @@ def __init__(self):
self._deps_map = {}
self._dep_caches = {}

self._id = self._id_counter[0]
self._id, = self._id_counter
self._id_counter[0] += 1
self._caches[self._id] = self

Expand All @@ -108,10 +108,9 @@ def finalize_callback(cache):
def __len__(self):
return len(self._cache)

@property
def id(self):
"""Return the unique :class:`int` ID associated with this cache.
:returns: The unique :class:`int` ID.
"""A unique :class:`int` ID associated with this :class:`.Cache`.
"""

return self._id
Expand Down Expand Up @@ -280,27 +279,27 @@ def clear(self):
"""Clear cache entries which depend on the associated variable.
"""

for cache in tuple(self._caches.valuerefs()):
cache = cache()
for cache_id in sorted(tuple(self._caches)):
cache = self._caches.get(cache_id, None)
if cache is not None:
cache.clear(self._id)
assert not cache.id() in self._caches
assert cache.id not in self._caches

def add(self, cache):
"""Add a new :class:`.Cache` to the :class:`.Caches`.
:arg cache: The :class:`.Cache` to add to the :class:`.Caches`.
"""

self._caches.setdefault(cache.id(), cache)
self._caches.setdefault(cache.id, cache)

def remove(self, cache):
"""Remove a :class:`.Cache` from the :class:`.Caches`.
:arg cache: The :class:`.Cache` to remove from the :class:`.Caches`.
"""

del self._caches[cache.id()]
del self._caches[cache.id]

def update(self, x):
"""Check for cache invalidation associated with a possible change in
Expand Down
9 changes: 5 additions & 4 deletions tlm_adjoint/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@ def __init__(self, referrers=None):
if referrers is None:
referrers = ()

self._id = self._id_counter[0]
self._id, = self._id_counter
self._id_counter[0] += 1
self._referrers = weakref.WeakValueDictionary()
self._references_dropped = False

self.add_referrer(*referrers)

@property
def id(self):
return self._id

Expand All @@ -47,22 +48,22 @@ def add_referrer(self, *referrers):
raise RuntimeError("Cannot call add_referrer method after "
"_drop_references method has been called")
for referrer in referrers:
referrer_id = referrer.id()
referrer_id = referrer.id
assert self._referrers.get(referrer_id, referrer) is referrer
self._referrers[referrer_id] = referrer

@gc_disabled
def referrers(self):
referrers = {}
remaining_referrers = {self.id(): self}
remaining_referrers = {self.id: self}
while len(remaining_referrers) > 0:
referrer_id, referrer = remaining_referrers.popitem()
if referrer_id not in referrers:
referrers[referrer_id] = referrer
for child in tuple(referrer._referrers.valuerefs()):
child = child()
if child is not None:
child_id = child.id()
child_id = child.id
if child_id not in referrers and child_id not in remaining_referrers: # noqa: E501
remaining_referrers[child_id] = child
return tuple(e for _, e in sorted(tuple(referrers.items()),
Expand Down
6 changes: 3 additions & 3 deletions tlm_adjoint/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ class SerialComm:
def __init__(self, *, _id=None):
self._id = _id
if self._id is None:
self._id = self._id_counter[0]
self._id, = self._id_counter
self._id_counter[0] -= 1

@property
Expand Down Expand Up @@ -489,7 +489,7 @@ def new_space_id():


def space_id(space):
"""Return the unique :class:`int` ID associated with a space.
"""Return a unique :class:`int` ID associated with a space.
:arg space: The space.
:returns: The unique :class:`int` ID.
Expand Down Expand Up @@ -884,7 +884,7 @@ def new_var_id():


def var_id(x):
"""Return the unique :class:`int` ID associated with a variable.
"""Return a unique :class:`int` ID associated with a variable.
Note that two variables share the same ID if they represent the same
symbolic variable -- for example if one variable represents both a variable
Expand Down
61 changes: 50 additions & 11 deletions tlm_adjoint/tangent_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
# -*- coding: utf-8 -*-

from .interface import (
check_space_types, is_var, var_id, var_name, var_new_tangent_linear)
check_space_types, is_var, var_id, var_is_replacement, var_name,
var_new_tangent_linear)

from .alias import gc_disabled
from .markers import ControlsMarker, FunctionalMarker

from collections import defaultdict
Expand All @@ -28,10 +30,16 @@ def tlm_key(M, dM):
else:
dM = tuple(dM)

if any(map(var_is_replacement, M)):
raise ValueError("Invalid tangent-linear")
if any(map(var_is_replacement, dM)):
raise ValueError("Invalid tangent-linear")

if len(set(M)) != len(M):
raise ValueError("Invalid tangent-linear")
if len(M) != len(dM):
raise ValueError("Invalid tangent-linear")

for m, dm in zip(M, dM):
check_space_types(m, dm)

Expand Down Expand Up @@ -165,9 +173,26 @@ class TangentLinearMap:
direction defined by `dM`.
"""

_id_counter = [0]

def __init__(self, M, dM):
(M, dM), _ = tlm_key(M, dM)

self._id, = self._id_counter
self._id_counter[0] += 1

self._X = weakref.WeakValueDictionary()

@gc_disabled
def weakref_finalize(X, tlm_map_id):
for x_id in sorted(tuple(X)):
x = X.get(x_id, None)
if x is not None:
getattr(x, "_tlm_adjoint__tangent_linears", {}).pop(tlm_map_id, None) # noqa: E501

weakref.finalize(self, weakref_finalize,
self._X, self._id)

if len(M) == 1:
self._name_suffix = \
"_tlm(%s,%s)" % (var_name(M[0]),
Expand All @@ -180,32 +205,46 @@ def __init__(self, M, dM):
assert len(M) == len(dM)
for m, dm in zip(M, dM):
if not hasattr(m, "_tlm_adjoint__tangent_linears"):
m._tlm_adjoint__tangent_linears = weakref.WeakKeyDictionary()
self._X[var_id(m)] = m
m._tlm_adjoint__tangent_linears = {}
# Do not set _tlm_adjoint__tlm_root_id, as dm cannot appear as the
# solution to an Equation
m._tlm_adjoint__tangent_linears[self] = dm
m._tlm_adjoint__tangent_linears[self.id] = dm

def __contains__(self, x):
if hasattr(x, "_tlm_adjoint__tangent_linears"):
return self in x._tlm_adjoint__tangent_linears
else:
return False
if not is_var(x):
raise TypeError("x must be a variable")
if var_is_replacement(x):
raise ValueError("x cannot be a replacement")

return self.id in getattr(x, "_tlm_adjoint__tangent_linears", {})

def __getitem__(self, x):
if not is_var(x):
raise TypeError("x must be a variable")
if var_is_replacement(x):
raise ValueError("x cannot be a replacement")

if not hasattr(x, "_tlm_adjoint__tangent_linears"):
x._tlm_adjoint__tangent_linears = weakref.WeakKeyDictionary()
if self not in x._tlm_adjoint__tangent_linears:
self._X[var_id(x)] = x
x._tlm_adjoint__tangent_linears = {}
if self.id not in x._tlm_adjoint__tangent_linears:
tau_x = var_new_tangent_linear(
x, name=f"{var_name(x):s}{self._name_suffix:s}")
if tau_x is not None:
tau_x._tlm_adjoint__tlm_root_id = getattr(
x, "_tlm_adjoint__tlm_root_id", var_id(x))
x._tlm_adjoint__tangent_linears[self] = tau_x
x._tlm_adjoint__tangent_linears[self.id] = tau_x

return x._tlm_adjoint__tangent_linears[self.id]

@property
def id(self):
"""A unique :class:`int` ID associated with this
:class:`.TangentLinearMap`.
"""

return x._tlm_adjoint__tangent_linears[self]
return self._id


def J_tangent_linears(Js, blocks, *, max_adjoint_degree=None):
Expand Down
12 changes: 8 additions & 4 deletions tlm_adjoint/tlm_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def finalize_callback(to_drop_references, finalizes):
if MPI is not None:
self._id_counter[0] = self._comm.allreduce(
self._id_counter[0], op=MPI.MAX)
self._id = self._id_counter[0]
self._id, = self._id_counter
self._id_counter[0] += 1

self.reset(cp_method=cp_method, cp_parameters=cp_parameters)
Expand Down Expand Up @@ -514,11 +514,15 @@ def var_tlm(self, x, *args):
"""Return a tangent-linear variable.
:arg x: A variable whose tangent-linear variable should be returned.
Cannot not be a replacement.
:arg args: Identifies the tangent-linear. See
:meth:`.EquationManager.configure_tlm`.
:returns: The tangent-linear variable.
"""

if var_is_replacement(x):
raise ValueError("x cannot be a replacement")

tau = x
for _, key in map(lambda arg: tlm_key(*arg), args):
tau = self._tlm_map[key][tau]
Expand Down Expand Up @@ -685,7 +689,7 @@ def _tangent_linear(self, eq, M, dM):
if not X_ids.isdisjoint(set(key[1])):
raise ValueError("Invalid tangent-linear direction")

eq_id = eq.id()
eq_id = eq.id
eq_tlm_eqs = self._tlm_eqs.setdefault(eq_id, {})

tlm_map = self._tlm_map[key]
Expand Down Expand Up @@ -714,7 +718,7 @@ def _tangent_linear(self, eq, M, dM):
def _add_equation_finalizes(self, eq):
for referrer in eq.referrers():
assert not isinstance(referrer, WeakAlias)
referrer_id = referrer.id()
referrer_id = referrer.id
if referrer_id not in self._finalizes:
@gc_disabled
def finalize_callback(self_ref, referrer_alias, referrer_id):
Expand All @@ -739,7 +743,7 @@ def drop_references(self):
referrer = self._to_drop_references.pop()
referrer._drop_references()
if isinstance(referrer, Equation):
referrer_id = referrer.id()
referrer_id = referrer.id
if referrer_id in self._tlm_eqs:
del self._tlm_eqs[referrer_id]

Expand Down

0 comments on commit d5492fe

Please sign in to comment.