From b84c9f6201c240a4fd2b3ac01b91829c5456900e Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Mon, 16 Oct 2023 09:46:07 +0100 Subject: [PATCH 1/3] FEniCS/Firedrake backends: Move/copy _code_generator code --- docs/source/conf.py | 10 +- setup.py | 1 - tlm_adjoint/_code_generator/__init__.py | 1 - tlm_adjoint/fenics/__init__.py | 31 - .../block_system.py | 0 .../{_code_generator => fenics}/caches.py | 0 .../{_code_generator => fenics}/equations.py | 0 .../{_code_generator => fenics}/functions.py | 0 .../hessian_system.py | 0 tlm_adjoint/firedrake/__init__.py | 31 - tlm_adjoint/firedrake/block_system.py | 1322 +++++++++++++++++ tlm_adjoint/firedrake/caches.py | 476 ++++++ tlm_adjoint/firedrake/equations.py | 952 ++++++++++++ tlm_adjoint/firedrake/functions.py | 761 ++++++++++ tlm_adjoint/firedrake/hessian_system.py | 433 ++++++ 15 files changed, 3946 insertions(+), 72 deletions(-) delete mode 100644 tlm_adjoint/_code_generator/__init__.py rename tlm_adjoint/{_code_generator => fenics}/block_system.py (100%) rename tlm_adjoint/{_code_generator => fenics}/caches.py (100%) rename tlm_adjoint/{_code_generator => fenics}/equations.py (100%) rename tlm_adjoint/{_code_generator => fenics}/functions.py (100%) rename tlm_adjoint/{_code_generator => fenics}/hessian_system.py (100%) create mode 100644 tlm_adjoint/firedrake/block_system.py create mode 100644 tlm_adjoint/firedrake/caches.py create mode 100644 tlm_adjoint/firedrake/equations.py create mode 100644 tlm_adjoint/firedrake/functions.py create mode 100644 tlm_adjoint/firedrake/hessian_system.py diff --git a/docs/source/conf.py b/docs/source/conf.py index a152e6ba..b1b9a099 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -7,21 +7,15 @@ autoapi_type = "python" autoapi_dirs = ["../../tlm_adjoint"] -autoapi_ignore = ["*/_code_generator/__init__.py", - "*/checkpoint_schedules/__init__.py", +autoapi_ignore = ["*/checkpoint_schedules/__init__.py", "*/fenics/__init__.py", "*/fenics/backend.py", "*/fenics/backend_code_generator_interface.py", - "*/fenics/backend_interface.py", "*/fenics/backend_overrides.py", "*/firedrake/__init__.py", "*/firedrake/backend.py", "*/firedrake/backend_code_generator_interface.py", - "*/firedrake/backend_interface.py", - "*/firedrake/backend_overrides.py", - "*/numpy/__init__.py", - "*/numpy/backend.py", - "*/numpy/backend_interface.py"] + "*/firedrake/backend_overrides.py"] autoapi_add_toctree_entry = False autoapi_options = {"private-members": False} diff --git a/setup.py b/setup.py index 27b32642..52dce169 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,6 @@ url="https://github.com/tlm-adjoint/tlm_adjoint", license="GNU LGPL version 3", packages=["tlm_adjoint", - "tlm_adjoint._code_generator", "tlm_adjoint.checkpoint_schedules", "tlm_adjoint.fenics", "tlm_adjoint.firedrake"], diff --git a/tlm_adjoint/_code_generator/__init__.py b/tlm_adjoint/_code_generator/__init__.py deleted file mode 100644 index ff7215d7..00000000 --- a/tlm_adjoint/_code_generator/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Required by Sphinx diff --git a/tlm_adjoint/fenics/__init__.py b/tlm_adjoint/fenics/__init__.py index 24d4e809..59ea3c46 100644 --- a/tlm_adjoint/fenics/__init__.py +++ b/tlm_adjoint/fenics/__init__.py @@ -1,37 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import importlib -import sys - -modules = [("backend", "tlm_adjoint.fenics"), - ("functions", "tlm_adjoint._code_generator"), - ("backend_code_generator_interface", "tlm_adjoint.fenics"), - ("caches", "tlm_adjoint._code_generator"), - ("equations", "tlm_adjoint._code_generator"), - ("backend_interface", "tlm_adjoint.fenics"), - ("backend_overrides", "tlm_adjoint.fenics"), - ("fenics_equations", "tlm_adjoint.fenics"), - ("block_system", "tlm_adjoint._code_generator"), - ("hessian_system", "tlm_adjoint._code_generator")] - -for module_name, package in modules: - if package == "tlm_adjoint._code_generator": - sys.modules[f"tlm_adjoint.fenics.{module_name:s}"] \ - = importlib.import_module(f".{module_name:s}", - package="tlm_adjoint._code_generator") - else: - assert package == "tlm_adjoint.fenics" - sys.modules[f"tlm_adjoint._code_generator.{module_name:s}"] \ - = importlib.import_module(f".{module_name:s}", - package="tlm_adjoint.fenics") - -for module_name, package in modules: - del sys.modules[f"tlm_adjoint._code_generator.{module_name:s}"] -del sys.modules["tlm_adjoint._code_generator"] - -del importlib, sys, modules, module_name, package - from .. import * # noqa: E402,F401 del adjoint, alias, cached_hessian, caches, checkpointing, \ eigendecomposition, equation, equations, fixed_point, functional, \ diff --git a/tlm_adjoint/_code_generator/block_system.py b/tlm_adjoint/fenics/block_system.py similarity index 100% rename from tlm_adjoint/_code_generator/block_system.py rename to tlm_adjoint/fenics/block_system.py diff --git a/tlm_adjoint/_code_generator/caches.py b/tlm_adjoint/fenics/caches.py similarity index 100% rename from tlm_adjoint/_code_generator/caches.py rename to tlm_adjoint/fenics/caches.py diff --git a/tlm_adjoint/_code_generator/equations.py b/tlm_adjoint/fenics/equations.py similarity index 100% rename from tlm_adjoint/_code_generator/equations.py rename to tlm_adjoint/fenics/equations.py diff --git a/tlm_adjoint/_code_generator/functions.py b/tlm_adjoint/fenics/functions.py similarity index 100% rename from tlm_adjoint/_code_generator/functions.py rename to tlm_adjoint/fenics/functions.py diff --git a/tlm_adjoint/_code_generator/hessian_system.py b/tlm_adjoint/fenics/hessian_system.py similarity index 100% rename from tlm_adjoint/_code_generator/hessian_system.py rename to tlm_adjoint/fenics/hessian_system.py diff --git a/tlm_adjoint/firedrake/__init__.py b/tlm_adjoint/firedrake/__init__.py index 920afeb6..a95cde79 100644 --- a/tlm_adjoint/firedrake/__init__.py +++ b/tlm_adjoint/firedrake/__init__.py @@ -1,37 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import importlib -import sys - -modules = [("backend", "tlm_adjoint.firedrake"), - ("functions", "tlm_adjoint._code_generator"), - ("backend_code_generator_interface", "tlm_adjoint.firedrake"), - ("caches", "tlm_adjoint._code_generator"), - ("equations", "tlm_adjoint._code_generator"), - ("backend_interface", "tlm_adjoint.firedrake"), - ("backend_overrides", "tlm_adjoint.firedrake"), - ("firedrake_equations", "tlm_adjoint.firedrake"), - ("block_system", "tlm_adjoint._code_generator"), - ("hessian_system", "tlm_adjoint._code_generator")] - -for module_name, package in modules: - if package == "tlm_adjoint._code_generator": - sys.modules[f"tlm_adjoint.firedrake.{module_name:s}"] \ - = importlib.import_module(f".{module_name:s}", - package="tlm_adjoint._code_generator") - else: - assert package == "tlm_adjoint.firedrake" - sys.modules[f"tlm_adjoint._code_generator.{module_name:s}"] \ - = importlib.import_module(f".{module_name:s}", - package="tlm_adjoint.firedrake") - -for module_name, package in modules: - del sys.modules[f"tlm_adjoint._code_generator.{module_name:s}"] -del sys.modules["tlm_adjoint._code_generator"] - -del importlib, sys, modules, module_name, package - from .. import * # noqa: E402,F401 del adjoint, alias, cached_hessian, caches, checkpointing, \ eigendecomposition, equation, equations, fixed_point, functional, \ diff --git a/tlm_adjoint/firedrake/block_system.py b/tlm_adjoint/firedrake/block_system.py new file mode 100644 index 00000000..460a6958 --- /dev/null +++ b/tlm_adjoint/firedrake/block_system.py @@ -0,0 +1,1322 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +r"""This module is used by both the FEniCS and Firedrake backends, and +implements solvers for linear systems defined in mixed spaces. + +The :class:`System` class defines the block structure of the linear system, and +solves the system using an outer Krylov solver. A custom preconditioner can be +defined via the `pc_fn` callback to :meth:`System.solve`, and this +preconditioner can itself e.g. make use of further Krylov solvers. This +provides a Python interface for custom block preconditioners. + +Given a linear problem with a potentially singular matrix :math:`A` + +.. math:: + + A u = b, + +a :class:`System` instead solves the linear problem + +.. math:: + + \left[ (I - M U (U^* M U)^{-1} U^*) A (I - V (V^* C V)^{-1} V^* C) + + M U S V^* C \right] u = (I - M U (U^* M U)^{-1} U^*) b. + +Here + + - :math:`U` is a full rank matrix whose columns span the left nullspace for + a modified system matrix :math:`\tilde{A}`. + - :math:`V` is a full rank matrix with the same number of columns as + :math:`U`, whose columns span the nullspace for :math:`\tilde{A}`. + - :math:`V^* C V` and :math:`S` are invertible matrices. + - :math:`M` is a Hermitian positive definite matrix. + +Here the left nullspace for a matrix is defined to be the nullspace for its +Hermitian transpose, and the modified system matrix :math:`\tilde{A}` is +defined + +.. math:: + + \tilde{A} = (I - M U (U^* M U)^{-1} U^*) A (I - V (V^* C V)^{-1} V^* C). + +This has two primary use cases: + + 1. Where a matrix :math:`A` and right-hand-side :math:`b` are constructed + via finite element assembly on superspaces of the test space and trial + space. The typical example is in the application of homogeneous + essential Dirichlet boundary conditions. + + 2. Where the matrix :math:`A` is singular and :math:`b` is orthogonal to + the left nullspace of :math:`A`. Typically one would then choose + :math:`U` and :math:`V` so that their columns respectively span the left + nullspace and nullspace of :math:`A`, and the :class:`System` then seeks + a solution to the original problem subject to the linear constraints + :math:`V^* C u = 0`. + +Function spaces are defined via backend function spaces, and :class:`Sequence` +objects containing backend function spaces or similar :class:`Sequence` +objects. Similarly functions are defined via backend `Function` objects, or +:class:`Sequence` objects containing backend `Function` objects or similar +:class:`Sequence` objects. This defines a basic tree structure which is useful +e.g. when defining block matrices in terms of sub-block matrices. + +Elements of the tree are accessed in a consistent order using a depth first +search. Hence e.g. + +.. code-block:: python + + ((u_0, u_1), u_2) + +and + +.. code-block:: python + + (u_0, u_1, u_2) + +where `u_0`, `u_1`, and `u_2` are backend `Function` objects, are both valid +representations of a mixed space solution. + +Code in this module is written to use only backend functionality, and does not +use tlm_adjoint interfaces. Consequently if used directly, and in combination +with other tlm_adjoint code, space type warnings may be encountered. +""" + +# This is the only import from other tlm_adjoint modules that is permitted in +# this module +from .backend import backend + +if backend == "Firedrake": + from firedrake import ( + Cofunction, Constant, DirichletBC, Function, FunctionSpace, + TestFunction, assemble) + from firedrake.functionspaceimpl import WithGeometry as FunctionSpaceBase +elif backend == "FEniCS": + from fenics import ( + Constant, DirichletBC, Function, FunctionSpace, TestFunction, assemble) + from fenics import FunctionAssigner, as_backend_type + from dolfin.cpp.function import Constant as cpp_Constant +else: + raise ImportError(f"Unexpected backend: {backend}") + +import petsc4py.PETSc as PETSc +import ufl + +from abc import ABC, abstractmethod +from collections import deque +from collections.abc import Sequence +from contextlib import contextmanager +from enum import Enum +from functools import cached_property, wraps +import logging + +__all__ = \ + [ + "MixedSpace", + "BackendMixedSpace", + + "Nullspace", + "NoneNullspace", + "ConstantNullspace", + "UnityNullspace", + "DirichletBCNullspace", + "BlockNullspace", + + "Matrix", + "PETScMatrix", + "BlockMatrix", + "form_matrix", + + "System" + ] + + +# Following naming of PyOP2 Dat.vec_context access types +class Access(Enum): + RW = "RW" + READ = "READ" + WRITE = "WRITE" + + +def iter_sub(iterable, *, expand=None): + if expand is None: + def expand(e): + return e + + q = deque(map(expand, iterable)) + while len(q) > 0: + e = q.popleft() + if isinstance(e, Sequence) and not isinstance(e, str): + q.extendleft(map(expand, reversed(e))) + else: + yield e + + +def zip_sub(*iterables): + iterators = map(iter_sub, iterables) + yield from zip(*iterators) + + for iterator in iterators: + try: + next(iterator) + raise ValueError("Non-equal lengths") + except StopIteration: + pass + + +def tuple_sub(iterable, sequence): + iterator = iter_sub(iterable) + + def tuple_sub(iterator, value): + if isinstance(value, Sequence) and not isinstance(value, str): + return tuple(tuple_sub(iterator, e) for e in value) + return next(iterator) + + t = tuple_sub(iterator, sequence) + + try: + next(iterator) + raise ValueError("Non-equal lengths") + except StopIteration: + pass + + return t + + +class MixedSpace(ABC): + """Used to map between mixed and split versions of spaces. + + This class defines three representations for the space: + + 1. As a 'mixed space': A single function space defined using a + :class:`ufl.MixedElement`. + 2. As a 'split space': A tree defining the mixed space. Stored using + backend function space and :class:`tuple` objects, each + corresponding to a node in the tree. Function spaces correspond to + leaf nodes, and :class:`tuple` objects to other nodes in the tree. + 3. As a 'flattened space': A :class:`Sequence` containing leaf nodes of + the split space with an ordering determined using a depth first + search. + + This allows, for example, the construction: + + .. code-block:: python + + u_0 = Function(space_0, name='u_0') + u_1 = Function(space_1, name='u_1') + u_2 = Function(space_2, name='u_2') + + mixed_space = BackendMixedSpace(((space_0, space_1), space_2)) + u_fn = mixed_space.new_mixed() + + and then data can be copied to the function in the mixed space via + + .. code-block:: python + + mixed_space.split_to_mixed(u_fn, ((u_0, u_1), u_2)) + + and from the function in the mixed space via + + .. code-block:: python + + mixed_space.mixed_to_split(((u_0, u_1), u_2), u_fn) + + :arg spaces: The split space. + """ + + def __init__(self, spaces): + if isinstance(spaces, Sequence): + spaces = tuple(spaces) + else: + spaces = (spaces,) + spaces = tuple_sub(spaces, spaces) + flattened_spaces = tuple(iter_sub(spaces)) + + mesh = flattened_spaces[0].mesh() + for space in flattened_spaces[1:]: + if space.mesh() != mesh: + raise ValueError("Invalid mesh") + + if len(flattened_spaces) == 1: + mixed_space, = flattened_spaces + else: + mixed_element = ufl.classes.MixedElement( + *(space.ufl_element() for space in flattened_spaces)) + mixed_space = FunctionSpace(mesh, mixed_element) + + with vec(Function(mixed_space), Access.READ) as v: + n = v.getLocalSize() + N = v.getSize() + + self._mesh = mesh + self._spaces = spaces + self._flattened_spaces = flattened_spaces + self._mixed_space = mixed_space + self._sizes = (n, N) + + def mesh(self): + """ + :returns: The mesh associated with the space. + """ + + return self._mesh + + def split_space(self): + """ + :returns: The split space. + """ + + return self._spaces + + def flattened_space(self): + """ + :returns: The flattened space. + """ + + return self._flattened_spaces + + def mixed_space(self): + """ + :returns: The mixed space. + """ + + return self._mixed_space + + def new_split(self): + """ + :returns: A new function in the split space. + """ + + return tuple_sub(map(Function, self._flattened_spaces), self._spaces) + + def new_mixed(self): + """ + :returns: A new function in the mixed space. + """ + + return Function(self._mixed_space) + + @property + def sizes(self): + """ + A :class:`tuple`, `(n, N)`, where `n` is the number of process local + degrees of freedom and `N` is the number of global degrees of freedom, + each for the mixed space. + """ + + return self._sizes + + @abstractmethod + def mixed_to_split(self, u, u_fn): + """Copy data out of the mixed space representation. + + :arg u: A function in a compatible split space. + :arg u_fn: The function in the mixed space. + """ + + raise NotImplementedError + + @abstractmethod + def split_to_mixed(self, u_fn, u): + """Copy data into the mixed space representation. + + :arg u_fn: The function in the mixed space. + :arg u: A function in a compatible split space. + """ + + raise NotImplementedError + + +if backend == "Firedrake": + def mesh_comm(mesh): + return mesh.comm + + class BackendMixedSpace(MixedSpace): + def __init__(self, spaces): + if isinstance(spaces, Sequence): + spaces = tuple(spaces) + else: + spaces = (spaces,) + spaces = tuple_sub(spaces, spaces) + super().__init__( + tuple_sub((space if isinstance(space, FunctionSpaceBase) + else space.dual() for space in iter_sub(spaces)), + spaces)) + self._primal_dual_spaces = tuple(iter_sub(spaces)) + assert len(self._primal_dual_spaces) == len(self._flattened_spaces) + + def new_split(self): + u = [] + for space in self._primal_dual_spaces: + if isinstance(space, FunctionSpaceBase): + u.append(Function(space)) + else: + u.append(Cofunction(space)) + return tuple_sub(u, self._spaces) + + @staticmethod + def _iter_sub_fn(iterable): + def expand(e): + if isinstance(e, (Cofunction, Function)): + space = e.function_space() + if hasattr(space, "num_sub_spaces"): + return tuple(e.sub(i) + for i in range(space.num_sub_spaces())) + return e + + return iter_sub(iterable, expand=expand) + + def mixed_to_split(self, u, u_fn): + if len(self._flattened_spaces) == 1: + u, = tuple(iter_sub(u)) + with u_fn.dat.vec_ro as u_fn_v, u.dat.vec_wo as u_v: + u_fn_v.copy(result=u_v) + else: + for i, u_i in enumerate(self._iter_sub_fn(u)): + with u_fn.sub(i).dat.vec_ro as u_fn_i_v, u_i.dat.vec_wo as u_i_v: # noqa: E501 + u_fn_i_v.copy(result=u_i_v) + + def split_to_mixed(self, u_fn, u): + if len(self._flattened_spaces) == 1: + u, = tuple(iter_sub(u)) + with u_fn.dat.vec_wo as u_fn_v, u.dat.vec_ro as u_v: + u_v.copy(result=u_fn_v) + else: + for i, u_i in enumerate(self._iter_sub_fn(u)): + with u_fn.sub(i).dat.vec_wo as u_fn_i_v, u_i.dat.vec_ro as u_i_v: # noqa: E501 + u_i_v.copy(result=u_fn_i_v) + + def mat(a): + return a.petscmat + + @contextmanager + def vec(u, mode=Access.RW): + attribute_name = {Access.RW: "vec", + Access.READ: "vec_ro", + Access.WRITE: "vec_wo"}[mode] + with getattr(u.dat, attribute_name) as u_v: + yield u_v + + def bc_space(bc): + return bc.function_space() + + def bc_is_homogeneous(bc): + return isinstance(bc.function_arg, ufl.classes.Zero) + + def bc_domain_args(bc): + return (bc.sub_domain,) + + def apply_bcs(u, bcs): + if not isinstance(bcs, Sequence): + bcs = (bcs,) + if len(bcs) > 0 and not isinstance(u.function_space(), type(bcs[0].function_space())): # noqa: E501 + u_bc = u.riesz_representation("l2") + else: + u_bc = u + for bc in bcs: + bc.apply(u_bc) +elif backend == "FEniCS": + def mesh_comm(mesh): + return mesh.mpi_comm() + + class BackendMixedSpace(MixedSpace): + @cached_property + def _mixed_to_split_assigner(self): + return FunctionAssigner(list(self._flattened_spaces), + self._mixed_space) + + @cached_property + def _split_to_mixed_assigner(self): + return FunctionAssigner(self._mixed_space, + list(self._flattened_spaces)) + + def mixed_to_split(self, u, u_fn): + if len(self._flattened_spaces) == 1: + u, = tuple(iter_sub(u)) + u.assign(u_fn) + else: + self._mixed_to_split_assigner.assign(list(iter_sub(u)), + u_fn) + + def split_to_mixed(self, u_fn, u): + if len(self._flattened_spaces) == 1: + u, = tuple(iter_sub(u)) + u_fn.assign(u) + else: + self._split_to_mixed_assigner.assign(u_fn, + list(iter_sub(u))) + + def mat(a): + matrix = as_backend_type(a).mat() + if not isinstance(matrix, PETSc.Mat): + raise RuntimeError("PETSc backend required") + return matrix + + @contextmanager + def vec(u, mode=Access.RW): + if isinstance(u, Function): + u = u.vector() + u_v = as_backend_type(u).vec() + if not isinstance(u_v, PETSc.Vec): + raise RuntimeError("PETSc backend required") + + yield u_v + + if mode in {Access.RW, Access.WRITE}: + u.update_ghost_values() + + def bc_space(bc): + return FunctionSpace(bc.function_space()) + + def bc_is_homogeneous(bc): + # A weaker check with FEniCS, as the constant might be modified + return (isinstance(bc.value(), cpp_Constant) + and (bc.value().values() == 0.0).all()) + + def bc_domain_args(bc): + return (bc.sub_domain, bc.method()) + + def apply_bcs(u, bcs): + if not isinstance(bcs, Sequence): + bcs = (bcs,) + for bc in bcs: + bc.apply(u.vector()) +else: + raise ImportError(f"Unexpected backend: {backend}") + + +class Nullspace(ABC): + """Represents a matrix nullspace and left nullspace. + """ + + @abstractmethod + def apply_nullspace_transformation_lhs_right(self, x): + r"""Apply the nullspace transformation associated with a matrix action + on :math:`x`, + + .. math:: + + x \rightarrow (I - V (V^* C V)^{-1} V^* C) x. + + :arg x: Defines :math:`x`. + """ + + raise NotImplementedError + + @abstractmethod + def apply_nullspace_transformation_lhs_left(self, y): + r"""Apply the left nullspace transformation associated with a matrix + action, + + .. math:: + + y \rightarrow (I - M U (U^* M U)^{-1} U^*) y. + + :arg y: Defines :math:`y`. + """ + + raise NotImplementedError + + @abstractmethod + def constraint_correct_lhs(self, x, y): + r"""Add the linear constraint term to :math:`y`, + + .. math:: + + y \rightarrow y + M U S V^* C x. + + :arg x: Defines :math:`x`. + :arg y: Defines :math:`y`. + """ + + raise NotImplementedError + + @abstractmethod + def pc_constraint_correct_soln(self, u, b): + r"""Add the preconditioner linear constraint term to :math:`u`, + + .. math:: + + u \rightarrow u + V \tilde{S}^{-1} U^* b, + + with + + .. math:: + + \tilde{S}^{-1} = + \left( V^* C V \right)^{-1} + S^{-1} + \left( U^* M U \right)^{-1}. + + :arg u: Defines :math:`u`. + :arg b: Defines :math:`b`. + """ + + raise NotImplementedError + + def correct_soln(self, x): + """Correct the linear system solution so that it is orthogonal to + space spanned by the columns of :math:`V`. + + :arg x: The linear system solution, to be corrected. + """ + + self.apply_nullspace_transformation_lhs_right(x) + + def pre_mult_correct_lhs(self, x): + """Apply the pre-left-multiplication nullspace transformation. + + :arg x: Defines the vector on which the matrix action is computed. + """ + + self.apply_nullspace_transformation_lhs_right(x) + + def post_mult_correct_lhs(self, x, y): + """Apply the post-left-multiplication nullspace transformation, and add + the linear constraint term. + + :arg x: Defines the vector on which the matrix action is computed, and + used to add the linear constraint term. If `None` is supplied then + the linear constraint term is not added. + :arg y: Defines the result of the matrix action on `x`. + """ + + self.apply_nullspace_transformation_lhs_left(y) + if x is not None: + self.constraint_correct_lhs(x, y) + + def correct_rhs(self, b): + """Correct the linear system right-hand-side so that it is orthogonal + to the space spanned by the columns of :math:`U`. + + :arg b: The linear system right-hand-side, to be corrected. + """ + + self.apply_nullspace_transformation_lhs_left(b) + + def pc_pre_mult_correct(self, b): + """Apply the pre-preconditioner-application nullspace transformation. + + :arg b: Defines the vector on which the preconditioner action is + computed. + """ + + self.apply_nullspace_transformation_lhs_left(b) + + def pc_post_mult_correct(self, u, b): + """Apply the post-preconditioner-application left nullspace + transformation, and add the linear constraint term. + + :arg u: Defines the result of the preconditioner action on `b`. + :arg b: Defines the vector on which the preconditioner action is + computed, and used to add the linear constraint term. If `None` is + supplied then the linear constraint term is not added. + """ + + self.apply_nullspace_transformation_lhs_right(u) + if b is not None: + self.pc_constraint_correct_soln(u, b) + + +class NoneNullspace(Nullspace): + """An empty nullspace and left nullspace. + """ + + def apply_nullspace_transformation_lhs_right(self, x): + pass + + def apply_nullspace_transformation_lhs_left(self, y): + pass + + def constraint_correct_lhs(self, x, y): + pass + + def pc_constraint_correct_soln(self, u, b): + pass + + +class ConstantNullspace(Nullspace): + r"""A nullspace and left nullspace spanned by the vector of ones. + + Here :math:`V = U`, :math:`U` is a single column matrix whose elements are + ones, :math:`C = M`, and :math:`M` is an identity matrix. + + :arg alpha: Defines the linear constraint matrix :math:`S = \left( \alpha / + N \right)` where :math:`N` is the length of the vector of ones. + """ + + def __init__(self, *, alpha=1.0): + super().__init__() + self._alpha = alpha + + @staticmethod + def _correct(x, y, *, alpha=1.0): + with vec(x, Access.READ) as x_v: + x_sum = x_v.sum() + N = x_v.getSize() + + with vec(y) as y_v: + y_v.shift(alpha * x_sum / float(N)) + + def apply_nullspace_transformation_lhs_right(self, x): + self._correct(x, x, alpha=-1.0) + + def apply_nullspace_transformation_lhs_left(self, y): + self._correct(y, y, alpha=-1.0) + + def constraint_correct_lhs(self, x, y): + self._correct(x, y, alpha=self._alpha) + + def pc_constraint_correct_soln(self, u, b): + self._correct(b, u, alpha=1.0 / self._alpha) + + +class UnityNullspace(Nullspace): + r"""A nullspace and left nullspace defined by the unity-valued function. + + Here :math:`V = U`, :math:`U` is a single column matrix containing the + degree-of-freedom vector for the unity-valued function, :math:`C = M`, + and :math:`M` is the mass matrix. + + :arg space: A scalar-valued function space containing the unity-valued + function. + :arg alpha: Defines the linear constraint matrix :math:`S = \alpha \left( + U^* M U \right)^{-1}`. + """ + + def __init__(self, space, *, alpha=1.0): + U = Function(space, name="U") + U.assign(Constant(1.0)) + MU = assemble(ufl.inner(U, TestFunction(space)) * ufl.dx) + UMU = assemble(ufl.inner(U, U) * ufl.dx) + + self._alpha = alpha + self._U = U + self._MU = MU + self._UMU = UMU + + @staticmethod + def _correct(x, y, u, v, *, alpha=1.0): + with vec(x, Access.READ) as x_v, vec(u, Access.READ) as u_v: + u_x = x_v.dot(u_v) + + with vec(y) as y_v, vec(v, Access.READ) as v_v: + y_v.axpy(alpha * u_x, v_v) + + def apply_nullspace_transformation_lhs_right(self, x): + self._correct( + x, x, self._MU, self._U, alpha=-1.0 / self._UMU) + + def apply_nullspace_transformation_lhs_left(self, y): + self._correct( + y, y, self._U, self._MU, alpha=-1.0 / self._UMU) + + def constraint_correct_lhs(self, x, y): + self._correct( + x, y, self._MU, self._MU, alpha=self._alpha / self._UMU) + + def pc_constraint_correct_soln(self, u, b): + self._correct( + b, u, self._U, self._U, alpha=1.0 / (self._alpha * self._UMU)) + + +class DirichletBCNullspace(Nullspace): + r"""A nullspace and left nullspace associated with homogeneous Dirichlet + boundary conditions. + + Here :math:`V = U`, :math:`U` is a zero-one matrix with exactly one + non-zero per column corresponding to one boundary condition + degree-of-freedom, :math:`C = M`, and :math:`M` is an identity matrix. + + :arg bcs: The Dirichlet boundary conditions. + :arg alpha: Defines the linear constraint matrix :math:`S = \alpha M`. + """ + + def __init__(self, bcs, *, alpha=1.0): + if isinstance(bcs, Sequence): + bcs = tuple(bcs) + else: + bcs = (bcs,) + + space = bc_space(bcs[0]) + for bc in bcs: + if bc_space(bc) != space: + raise ValueError("Invalid space") + if not bc_is_homogeneous(bc): + raise ValueError("Homogeneous boundary conditions required") + + super().__init__() + self._bcs = bcs + self._alpha = alpha + self._c = Function(space) + + def apply_nullspace_transformation_lhs_right(self, x): + apply_bcs(x, self._bcs) + + def apply_nullspace_transformation_lhs_left(self, y): + apply_bcs(y, self._bcs) + + def _constraint_correct_lhs(self, x, y, *, alpha=1.0): + with vec(self._c, Access.WRITE) as c_v: + c_v.zeroEntries() + + apply_bcs(self._c, + tuple(DirichletBC(x.function_space(), x, *bc_domain_args(bc)) + for bc in self._bcs)) + + with vec(self._c, Access.READ) as c_v, vec(y) as y_v: + y_v.axpy(alpha, c_v) + + def constraint_correct_lhs(self, x, y): + self._constraint_correct_lhs(x, y, alpha=self._alpha) + + def pc_constraint_correct_soln(self, u, b): + self._constraint_correct_lhs(b, u, alpha=1.0 / self._alpha) + + +class BlockNullspace(Nullspace): + """Nullspaces for a mixed space. + + :arg nullspaces: A :class:`Nullspace` or a :class:`Sequence` of + :class:`Nullspace` objects defining the nullspace. `None` indicates a + :class:`NoneNullspace`. + """ + + def __init__(self, nullspaces): + if not isinstance(nullspaces, Sequence): + nullspaces = (nullspaces,) + + nullspaces = list(nullspaces) + for i, nullspace in enumerate(nullspaces): + if nullspace is None: + nullspaces[i] = NoneNullspace() + nullspaces = tuple(nullspaces) + + super().__init__() + self._nullspaces = nullspaces + + def __new__(cls, nullspaces, *args, **kwargs): + if not isinstance(nullspaces, Sequence): + nullspaces = (nullspaces,) + for nullspace in nullspaces: + if nullspace is not None \ + and not isinstance(nullspace, NoneNullspace): + break + else: + return NoneNullspace() + return super().__new__(cls) + + def __getitem__(self, key): + return self._nullspaces[key] + + def __iter__(self): + yield from self._nullspaces + + def __len__(self): + return len(self._nullspaces) + + def apply_nullspace_transformation_lhs_right(self, x): + assert len(self._nullspaces) == len(x) + for nullspace, x_i in zip(self._nullspaces, x): + nullspace.apply_nullspace_transformation_lhs_right(x_i) + + def apply_nullspace_transformation_lhs_left(self, y): + assert len(self._nullspaces) == len(y) + for nullspace, y_i in zip(self._nullspaces, y): + nullspace.apply_nullspace_transformation_lhs_left(y_i) + + def constraint_correct_lhs(self, x, y): + assert len(self._nullspaces) == len(x) + assert len(self._nullspaces) == len(y) + for nullspace, x_i, y_i in zip(self._nullspaces, x, y): + nullspace.constraint_correct_lhs(x_i, y_i) + + def pc_constraint_correct_soln(self, u, b): + assert len(self._nullspaces) == len(u) + assert len(self._nullspaces) == len(b) + for nullspace, u_i, b_i in zip(self._nullspaces, u, b): + nullspace.pc_constraint_correct_soln(u_i, b_i) + + +class Matrix(ABC): + r"""Represents a matrix :math:`A` mapping :math:`V \rightarrow W`. + + Note that :math:`V` and :math:`W` need not correspond directly to discrete + function spaces as defined by `arg_space` and `action_space`, but may + instead e.g. be defined via one or more antidual spaces. + + :arg arg_space: Defines the space `V`. + :arg action_space: Defines the space `W`. + """ + + def __init__(self, arg_space, action_space): + if isinstance(arg_space, Sequence): + arg_space = tuple(arg_space) + arg_space = tuple_sub(arg_space, arg_space) + if isinstance(action_space, Sequence): + action_space = tuple(action_space) + action_space = tuple_sub(action_space, action_space) + + self._arg_space = arg_space + self._action_space = action_space + + def arg_space(self): + """ + :returns: The space defining :math:`V`. + """ + + return self._arg_space + + def action_space(self): + """ + :returns: The space defining :math:`W`. + """ + + return self._action_space + + @abstractmethod + def mult_add(self, x, y): + """Add :math:`A x` to :math:`y`. + + :arg x: Defines :math:`x`. Should not be modified. + :arg y: Defines :math:`y`. + """ + + raise NotImplementedError + + +class PETScMatrix(Matrix): + r"""A :class:`Matrix` associated with a PETSc matrix :math:`A` mapping + :math:`V \rightarrow W`. + + :arg arg_space: Defines the space `V`. + :arg action_space: Defines the space `W`. + :arg a: The PETSc matrix. + """ + + def __init__(self, arg_space, action_space, a): + super().__init__(arg_space, action_space) + self._matrix = a + + def mult_add(self, x, y): + matrix = mat(self._matrix) + with vec(x, Access.READ) as x_v, vec(y) as y_v: + matrix.multAdd(x_v, y_v, y_v) + + +def form_matrix(a, *args, **kwargs): + """Construct a :class:`PETScMatrix` associated with a given sesquilinear + form. + + :arg a: A :class:`ufl.Form` defining the sesquilinear form. + :returns: The :class:`PETScMatrix`. + + Remaining arguments are passed to the backend :func:`assemble`. + """ + + test, trial = a.arguments() + assert test.number() < trial.number() + + return PETScMatrix( + trial.function_space(), test.function_space(), + assemble(a, *args, **kwargs)) + + +class BlockMatrix(Matrix): + r"""A matrix :math:`A` mapping :math:`V \rightarrow W`, where :math:`V` and + :math:`W` are defined by mixed spaces. + + :arg arg_spaces: Defines the space `V`. + :arg action_spaces: Defines the space `W`. + :arg block: A :class:`Mapping` defining the blocks of the matrix. Items are + `((i, j), block)` defining a :class:`ufl.Form` or :class:`Matrix` for + the block in row `i` and column `j`. A value for `block` of `None` + indicates a zero block. + """ + + def __init__(self, arg_spaces, action_spaces, blocks=None): + if not isinstance(blocks, BlockMatrix) \ + and isinstance(blocks, (Matrix, ufl.classes.Form)): + blocks = {(0, 0): blocks} + + super().__init__(arg_spaces, action_spaces) + self._blocks = {} + + if blocks is not None: + self.update(blocks) + + def __contains__(self, key): + i, j = key + return (i, j) in self._blocks + + def __iter__(self): + yield from self.keys() + + def __getitem__(self, key): + i, j = key + return self._blocks[(i, j)] + + def __setitem__(self, key, value): + i, j = key + if value is None: + self.pop((i, j), None) + else: + if isinstance(value, ufl.classes.Form): + value = form_matrix(value) + if value.arg_space() != self._arg_space[j]: + raise ValueError("Invalid space") + if value.action_space() != self._action_space[i]: + raise ValueError("Invalid space") + self._blocks[(i, j)] = value + + def __delitem__(self, key): + i, j = key + del self._blocks[(i, j)] + + def __len__(self): + return len(self._blocks) + + def keys(self): + yield from sorted(self._blocks) + + def values(self): + for (i, j) in self: + yield self[(i, j)] + + def items(self): + yield from zip(self.keys(), self.values()) + + def update(self, other): + for (i, j), block in other.items(): + self[(i, j)] = block + + def pop(self, key, *args, **kwargs): + i, j = key + return self._blocks.pop((i, j), *args, **kwargs) + + def mult_add(self, x, y): + for (i, j), block in self.items(): + block.mult_add(x[j], y[i]) + + +class PETScInterface: + def __init__(self, arg_space, action_space, nullspace): + self._arg_space = arg_space + self._action_space = action_space + self._nullspace = nullspace + + self._x = arg_space.new_split() + self._y = action_space.new_split() + + if len(arg_space.flattened_space()) == 1: + self._x_fn, = tuple(iter_sub(self._x)) + else: + self._x_fn = arg_space.new_mixed() + if len(action_space.flattened_space()) == 1: + self._y_fn, = tuple(iter_sub(self._y)) + else: + self._y_fn = action_space.new_mixed() + + if isinstance(self._nullspace, NoneNullspace): + self._x_c = self._x + else: + self._x_c = arg_space.new_split() + + def _pre_mult(self, x_petsc): + with vec(self._x_fn, Access.WRITE) as x_v: + # assert x_petsc.getSizes() == x_v.getSizes() + x_petsc.copy(result=x_v) + if len(self._arg_space.flattened_space()) != 1: + self._arg_space.mixed_to_split(self._x, self._x_fn) + + if not isinstance(self._nullspace, NoneNullspace): + for x_i, x_c_i in zip_sub(self._x, self._x_c): + with vec(x_c_i, Access.WRITE) as x_c_i_v, vec(x_i, Access.READ) as x_i_v: # noqa: E501 + x_i_v.copy(result=x_c_i_v) + + for y_i in iter_sub(self._y): + with vec(y_i, Access.WRITE) as y_i_v: + y_i_v.zeroEntries() + + def _post_mult(self, y_petsc): + if len(self._action_space.flattened_space()) != 1: + self._action_space.split_to_mixed(self._y_fn, self._y) + + with vec(self._y_fn, Access.READ) as y_v: + assert y_petsc.getSizes() == y_v.getSizes() + y_v.copy(result=y_petsc) + + +class SystemMatrix(PETScInterface): + def __init__(self, arg_space, action_space, matrix, nullspace): + if matrix.arg_space() != arg_space.split_space(): + raise ValueError("Invalid space") + if matrix.action_space() != action_space.split_space(): + raise ValueError("Invalid space") + + super().__init__(arg_space, action_space, nullspace) + self._matrix = matrix + + def mult(self, A, x, y): + self._pre_mult(x) + + if not isinstance(self._nullspace, NoneNullspace): + self._nullspace.pre_mult_correct_lhs(self._x_c) + self._matrix.mult_add(self._x_c, self._y) + if not isinstance(self._nullspace, NoneNullspace): + self._nullspace.post_mult_correct_lhs(self._x, self._y) + + self._post_mult(y) + + +class Preconditioner(PETScInterface): + def __init__(self, arg_space, action_space, pc_fn, nullspace): + super().__init__(arg_space, action_space, nullspace) + self._pc_fn = pc_fn + + def apply(self, pc, x, y): + self._pre_mult(x) + + if not isinstance(self._nullspace, NoneNullspace): + self._nullspace.pc_pre_mult_correct(self._x_c) + self._pc_fn(self._y, self._x_c) + if not isinstance(self._nullspace, NoneNullspace): + self._nullspace.pc_post_mult_correct( + self._y, self._x) + + self._post_mult(y) + + +class System: + """A linear system + + .. math:: + + A u = b. + + :arg arg_spaces: Defines the space for `u`. + :arg action_spaces: Defines the space for `b`. + :arg blocks: One of + + - A :class:`Matrix` or :class:`ufl.Form` defining :math:`A`. + - A :class:`Mapping` with items `((i, j), block)` where the matrix + associated with the block in the `i` th and `j` th column is defined + by `block`. Each `block` is a :class:`Matrix` or :class:`ufl.Form`, + or `None` to indicate a zero block. + + :arg nullspaces: A :class:`Nullspace` or a :class:`Sequence` of + :class:`Nullspace` objects defining the nullspace and left nullspace of + :math:`A`. `None` indicates a :class:`NoneNullspace`. + :arg comm: MPI communicator. + """ + + def __init__(self, arg_spaces, action_spaces, blocks, *, + nullspaces=None, comm=None): + if isinstance(arg_spaces, MixedSpace): + arg_space = arg_spaces + else: + arg_space = BackendMixedSpace(arg_spaces) + arg_spaces = arg_space.split_space() + if isinstance(action_spaces, MixedSpace): + action_space = action_spaces + else: + action_space = BackendMixedSpace(action_spaces) + action_spaces = action_space.split_space() + + matrix = BlockMatrix(arg_spaces, action_spaces, blocks) + + nullspace = BlockNullspace(nullspaces) + if isinstance(nullspace, BlockNullspace): + if len(nullspace) != len(arg_spaces): + raise ValueError("Invalid space") + if len(nullspace) != len(action_spaces): + raise ValueError("Invalid space") + + if comm is None: + self._comm = mesh_comm(arg_space.mesh()) + else: + self._comm = comm + self._arg_space = arg_space + self._action_space = action_space + self._matrix = matrix + self._nullspace = nullspace + + def solve(self, u, b, *, + solver_parameters=None, pc_fn=None, + pre_callback=None, post_callback=None, + correct_initial_guess=True, correct_solution=True): + """Solve the linear system. + + :arg u: Defines the solution :math:`u`. + :arg b: Defines the right-hand-side :math:`b`. + :arg solver_parameters: A :class:`Mapping` defining outer Krylov solver + parameters. Parameters (a number of which are based on FEniCS + solver parameters) are: + + - `'linear_solver'`: The Krylov solver type, default `'fgmres'`. + - `'pc_side'`: Overrides the PETSc default preconditioning side. + - `'relative_tolerance'`: Relative tolerance. Required. + - `'absolute_tolerance'`: Absolute tolerance. Required. + - `'divergence_limit'`: Overrides the default divergence limit. + - `'maximum_iterations'`: Maximum number of iterations. Default + 1000. + - `'norm_type'`: Overrides the default convergence norm definition. + - `'nonzero_initial_guess'`: Whether to use a non-zero initial + guess, defined by the input `u`. Default `True`. + - `'gmres_restart'`: Overrides the default GMRES restart parameter. + + :arg pc_fn: Defines the application of a preconditioner. A callable + + .. code-block:: python + + def pc_fn(u, b): + + The preconditioner is applied to `b`, and the result stored in `u`. + Defaults to an identity. + :arg pre_callback: A callable accepting a single + :class:`petsc4py.PETSc.KSP` argument. Used for detailed manual + configuration. Called after all other configuration options are + set, but before the :meth:`petsc4py.PETSc.KSP.setUp` method is + called. + :arg post_callback: A callable accepting a single + :class:`petsc4py.PETSc.KSP` argument. Called after the + :meth:`petsc4py.PETSc.KSP.solve` method has been called. + :arg correct_initial_guess: Whether to apply a nullspace correction to + the initial guess. + :arg correct_solution: Whether to apply a nullspace correction to + the solution. + :returns: The number of Krylov iterations. + """ + + if solver_parameters is None: + solver_parameters = {} + + if isinstance(u, Sequence): + u = tuple(u) + else: + u = (u,) + + if pc_fn is not None: + pc_fn_u = pc_fn + + @wraps(pc_fn_u) + def pc_fn(u, b): + u, = tuple(iter_sub(u)) + return pc_fn_u(u, b) + u = tuple_sub(u, self._arg_space.split_space()) + + if isinstance(b, Sequence): + b = tuple(b) + else: + b = (b,) + + if pc_fn is not None: + pc_fn_b = pc_fn + + @wraps(pc_fn_b) + def pc_fn(u, b): + b, = tuple(iter_sub(b)) + return pc_fn_b(u, b) + b = tuple_sub(b, self._action_space.split_space()) + + if tuple(u_i.function_space() for u_i in iter_sub(u)) \ + != self._arg_space.flattened_space(): + raise ValueError("Invalid space") + for b_i, space in zip_sub(b, self._action_space.split_space()): + if b_i is not None and b_i.function_space() != space: + raise ValueError("Invalid space") + + b_c = self._action_space.new_split() + for b_c_i, b_i in zip_sub(b_c, b): + if b_i is not None: + with vec(b_c_i, Access.WRITE) as b_c_i_v, vec(b_i, Access.READ) as b_i_v: # noqa: E501 + b_i_v.copy(result=b_c_i_v) + + A = SystemMatrix(self._arg_space, self._action_space, + self._matrix, self._nullspace) + + mat_A = PETSc.Mat().createPython( + (self._action_space.sizes, self._arg_space.sizes), A, + comm=self._comm) + mat_A.setUp() + + if pc_fn is not None: + A_pc = Preconditioner(self._action_space, self._arg_space, + pc_fn, self._nullspace) + pc = PETSc.PC().createPython( + A_pc, comm=self._comm) + pc.setOperators(mat_A) + pc.setUp() + + ksp_solver = PETSc.KSP().create(comm=self._comm) + ksp_solver.setType(solver_parameters.get("linear_solver", "fgmres")) + if pc_fn is not None: + ksp_solver.setPC(pc) + if "pc_side" in solver_parameters: + ksp_solver.setPCSide(solver_parameters["pc_side"]) + ksp_solver.setOperators(mat_A) + ksp_solver.setTolerances( + rtol=solver_parameters["relative_tolerance"], + atol=solver_parameters["absolute_tolerance"], + divtol=solver_parameters.get("divergence_limit", None), + max_it=solver_parameters.get("maximum_iterations", 1000)) + ksp_solver.setInitialGuessNonzero( + solver_parameters.get("nonzero_initial_guess", True)) + ksp_solver.setNormType( + solver_parameters.get( + "norm_type", PETSc.KSP.NormType.DEFAULT)) + if "gmres_restart" in solver_parameters: + ksp_solver.setGMRESRestart(solver_parameters["gmres_restart"]) + + logger = logging.getLogger("tlm_adjoint.System") + + def monitor(ksp_solver, it, r_norm): + logger.debug(f"KSP: " + f"iteration {it:d}, " + f"residual norm {r_norm:.16e}") + + ksp_solver.setMonitor(monitor) + + if correct_initial_guess: + self._nullspace.correct_soln(u) + self._nullspace.correct_rhs(b_c) + + if len(self._arg_space.flattened_space()) == 1: + u_fn, = tuple(iter_sub(u)) + else: + u_fn = self._arg_space.new_mixed() + self._arg_space.split_to_mixed(u_fn, u) + if len(self._action_space.flattened_space()) == 1: + b_fn, = tuple(iter_sub(b_c)) + else: + b_fn = self._action_space.new_mixed() + self._action_space.split_to_mixed(b_fn, b_c) + del b_c + + if pre_callback is not None: + pre_callback(ksp_solver) + ksp_solver.setUp() + with vec(u_fn) as u_v, vec(b_fn) as b_v: + ksp_solver.solve(b_v, u_v) + if post_callback is not None: + post_callback(ksp_solver) + del b_fn + + if len(self._arg_space.flattened_space()) != 1: + self._arg_space.mixed_to_split(u, u_fn) + del u_fn + + if correct_solution: + # Not needed if the linear problem were to be solved exactly + self._nullspace.correct_soln(u) + + if ksp_solver.getConvergedReason() <= 0: + raise RuntimeError("Convergence failure") + ksp_its = ksp_solver.getIterationNumber() + + ksp_solver.destroy() + mat_A.destroy() + if pc_fn is not None: + pc.destroy() + + return ksp_its diff --git a/tlm_adjoint/firedrake/caches.py b/tlm_adjoint/firedrake/caches.py new file mode 100644 index 00000000..30339d61 --- /dev/null +++ b/tlm_adjoint/firedrake/caches.py @@ -0,0 +1,476 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +"""This module is used by both the FEniCS and Firedrake backends, and +implements finite element assembly and linear solver data caching. +""" + +from .backend import TrialFunction, backend_DirichletBC, backend_Function +from ..interface import is_var, var_id, var_is_cached, var_space +from .backend_code_generator_interface import ( + assemble, assemble_arguments, assemble_matrix, complex_mode, linear_solver, + matrix_copy, parameters_key) + +from ..caches import Cache + +from .functions import ( + ReplacementFunction, derivative, eliminate_zeros, extract_coefficients, + replaced_form) + +from collections import defaultdict +import ufl + +__all__ = \ + [ + "AssemblyCache", + "assembly_cache", + "set_assembly_cache", + + "LinearSolverCache", + "linear_solver_cache", + "set_linear_solver_cache", + ] + + +def is_cached(expr): + for c in extract_coefficients(expr): + if not is_var(c) or not var_is_cached(c): + return False + return True + + +def form_simplify_sign(form): + integrals = [] + + for integral in form.integrals(): + integrand = integral.integrand() + + integral_sign = None + while isinstance(integrand, ufl.classes.Product): + a, b = integrand.ufl_operands + if isinstance(a, ufl.classes.IntValue) and a == -1: + if integral_sign is None: + integral_sign = -1 + else: + integral_sign = -integral_sign + integrand = b + elif isinstance(b, ufl.classes.IntValue) and b == -1: + if integral_sign is None: + integral_sign = -1 + else: + integral_sign = -integral_sign + integrand = a + else: + break + if integral_sign is not None: + if integral_sign < 0: + integral = integral.reconstruct(integrand=-integrand) + else: + integral = integral.reconstruct(integrand=integrand) + + integrals.append(integral) + + return ufl.classes.Form(integrals) + + +def form_simplify_conj(form): + if complex_mode: + def expr_conj(expr): + if isinstance(expr, ufl.classes.Conj): + x, = expr.ufl_operands + return expr_simplify_conj(x) + elif isinstance(expr, ufl.classes.Sum): + return sum(map(expr_conj, expr.ufl_operands), + ufl.classes.Zero(shape=expr.ufl_shape)) + elif isinstance(expr, ufl.classes.Product): + x, y = expr.ufl_operands + return expr_conj(x) * expr_conj(y) + else: + return ufl.conj(expr) + + def expr_simplify_conj(expr): + if isinstance(expr, ufl.classes.Conj): + x, = expr.ufl_operands + return expr_conj(x) + elif isinstance(expr, ufl.classes.Sum): + return sum(map(expr_simplify_conj, expr.ufl_operands), + ufl.classes.Zero(shape=expr.ufl_shape)) + elif isinstance(expr, ufl.classes.Product): + x, y = expr.ufl_operands + return expr_simplify_conj(x) * expr_simplify_conj(y) + else: + return expr + + def integral_simplify_conj(integral): + integrand = integral.integrand() + integrand = expr_simplify_conj(integrand) + return integral.reconstruct(integrand=integrand) + + integrals = list(map(integral_simplify_conj, form.integrals())) + return ufl.classes.Form(integrals) + else: + return ufl.algorithms.remove_complex_nodes.remove_complex_nodes(form) + + +def split_arity(form, x, argument): + form_arguments = form.arguments() + arity = len(form_arguments) + if arity >= 2: + raise ValueError("Invalid form arity") + if arity == 1 and form_arguments[0].number() != 0: + raise ValueError("Invalid form argument") + if argument.number() < arity: + raise ValueError("Invalid argument") + + if x not in extract_coefficients(form): + # No dependence on x + return ufl.classes.Form([]), form + + form_derivative = derivative(form, x, argument=argument, + enable_automatic_argument=False) + form_derivative = ufl.algorithms.expand_derivatives(form_derivative) + if x in extract_coefficients(form_derivative): + # Non-linear + return ufl.classes.Form([]), form + + try: + eq_form = ufl.algorithms.expand_derivatives( + ufl.replace(form, {x: argument})) + A = ufl.algorithms.formtransformations.compute_form_with_arity( + eq_form, arity + 1) + b = ufl.algorithms.formtransformations.compute_form_with_arity( + eq_form, arity) + except ufl.UFLException: + # UFL error encountered + return ufl.classes.Form([]), form + + try: + ufl.algorithms.check_arities.check_form_arity( + A, A.arguments(), complex_mode=complex_mode) + ufl.algorithms.check_arities.check_form_arity( + b, b.arguments(), complex_mode=complex_mode) + except ufl.algorithms.check_arities.ArityMismatch: + # Arity mismatch + return ufl.classes.Form([]), form + + if not is_cached(A): + # Non-cached higher arity form + return ufl.classes.Form([]), form + + # Success + return A, b + + +def split_terms(terms, base_integral, + cached_terms=None, mat_terms=None, non_cached_terms=None): + if cached_terms is None: + cached_terms = [] + if mat_terms is None: + mat_terms = defaultdict(lambda: []) + if non_cached_terms is None: + non_cached_terms = [] + + for term in terms: + if is_cached(term): + cached_terms.append(term) + elif isinstance(term, ufl.classes.Conj): + term_conj, = term.ufl_operands + if isinstance(term_conj, ufl.classes.Sum): + split_terms( + tuple(map(ufl.conj, term_conj.ufl_operands)), + base_integral, + cached_terms, mat_terms, non_cached_terms) + elif isinstance(term_conj, ufl.classes.Product): + x, y = term_conj.ufl_operands + split_terms( + (ufl.conj(x) * ufl.conj(y),), + base_integral, + cached_terms, mat_terms, non_cached_terms) + else: + non_cached_terms.append(term) + elif isinstance(term, ufl.classes.Sum): + split_terms(term.ufl_operands, base_integral, + cached_terms, mat_terms, non_cached_terms) + elif isinstance(term, ufl.classes.Product): + x, y = term.ufl_operands + if is_cached(x): + cached_sub, mat_sub, non_cached_sub = split_terms( + (y,), base_integral) + for term in cached_sub: + cached_terms.append(x * term) + for dep_id in mat_sub: + mat_terms[dep_id].extend( + x * mat_term for mat_term in mat_sub[dep_id]) + for term in non_cached_sub: + non_cached_terms.append(x * term) + elif is_cached(y): + cached_sub, mat_sub, non_cached_sub = split_terms( + (x,), base_integral) + for term in cached_sub: + cached_terms.append(term * y) + for dep_id in mat_sub: + mat_terms[dep_id].extend( + mat_term * y for mat_term in mat_sub[dep_id]) + for term in non_cached_sub: + non_cached_terms.append(term * y) + else: + non_cached_terms.append(term) + else: + mat_dep = None + for dep in extract_coefficients(term): + if not is_cached(dep): + if isinstance(dep, (backend_Function, ReplacementFunction)) and mat_dep is None: # noqa: E501 + mat_dep = dep + else: + mat_dep = None + break + if mat_dep is None: + non_cached_terms.append(term) + else: + term_form = ufl.classes.Form( + [base_integral.reconstruct(integrand=term)]) + mat_sub, non_cached_sub = split_arity( + term_form, mat_dep, + argument=TrialFunction(var_space(mat_dep))) + mat_sub = [integral.integrand() + for integral in mat_sub.integrals()] + non_cached_sub = [integral.integrand() + for integral in non_cached_sub.integrals()] + if len(mat_sub) > 0: + mat_terms[var_id(mat_dep)].extend(mat_sub) + non_cached_terms.extend(non_cached_sub) + + return cached_terms, dict(mat_terms), non_cached_terms + + +def split_form(form): + if len(form.arguments()) != 1: + raise ValueError("Arity 1 form required") + if not complex_mode: + form = ufl.algorithms.remove_complex_nodes.remove_complex_nodes(form) + + def add_integral(integrals, base_integral, terms): + if len(terms) > 0: + integrand = sum(terms, ufl.classes.Zero()) + integral = base_integral.reconstruct(integrand=integrand) + integrals.append(integral) + + cached_integrals = [] + mat_integrals = defaultdict(lambda: []) + non_cached_integrals = [] + for integral in form.integrals(): + cached_terms, mat_terms, non_cached_terms = \ + split_terms((integral.integrand(),), integral) + add_integral(cached_integrals, integral, cached_terms) + for dep_id in mat_terms: + add_integral(mat_integrals[dep_id], integral, mat_terms[dep_id]) + add_integral(non_cached_integrals, integral, non_cached_terms) + + cached_form = ufl.classes.Form(cached_integrals) + mat_forms = {} + for dep_id in mat_integrals: + mat_forms[dep_id] = ufl.classes.Form(mat_integrals[dep_id]) + non_cached_forms = ufl.classes.Form(non_cached_integrals) + + return cached_form, mat_forms, non_cached_forms + + +def form_dependencies(form): + deps = {} + for dep in extract_coefficients(form): + if is_var(dep): + deps.setdefault(var_id(dep), dep) + return deps + + +def form_key(form): + form = replaced_form(form) + form = ufl.algorithms.expand_derivatives(form) + form = ufl.algorithms.apply_algebra_lowering.apply_algebra_lowering(form) + form = ufl.algorithms.expand_indices(form) + form = form_simplify_conj(form) + form = form_simplify_sign(form) + return form + + +def assemble_key(form, bcs, assemble_kwargs): + return (form_key(form), tuple(bcs), parameters_key(assemble_kwargs)) + + +class AssemblyCache(Cache): + """A :class:`tlm_adjoint.caches.Cache` for finite element assembly data. + """ + + def assemble(self, form, *, + bcs=None, form_compiler_parameters=None, + linear_solver_parameters=None, replace_map=None): + """Perform finite element assembly and cache the result, or return a + previously cached result. + + :arg form: The :class:`ufl.Form` to assemble. + :arg bcs: Dirichlet boundary conditions. + :arg form_compiler_parameters: Form compiler parameters. + :arg linear_solver_parameters: Linear solver parameters. Required for + assembly parameters which appear in the linear solver parameters + -- in particular the Firedrake `'mat_type'` parameter. + :arg replace_map: A :class:`Mapping` defining a map from symbolic + variables to values. + :returns: A :class:`tuple` `(value_ref, value)`, where `value` is the + result of the finite element assembly, and `value_ref` is a + :class:`tlm_adjoint.caches.CacheRef` storing a reference to + `value`. + + - For an arity zero or arity one form `value_ref` stores the + assembled value. + - For an arity two form `value_ref` is a tuple `(A, b_bc)`. `A` + is the assembled matrix, and `b_bc` is a boundary condition + right-hand-side term which should be added after assembling a + right-hand-side with homogeneous boundary conditions applied. + `b_bc` may be `None` to indicate that this term is zero. + """ + + if bcs is None: + bcs = () + elif isinstance(bcs, backend_DirichletBC): + bcs = (bcs,) + if form_compiler_parameters is None: + form_compiler_parameters = {} + if linear_solver_parameters is None: + linear_solver_parameters = {} + + form = eliminate_zeros(form, force_non_empty_form=True) + arity = len(form.arguments()) + assemble_kwargs = assemble_arguments(arity, form_compiler_parameters, + linear_solver_parameters) + key = assemble_key(form, bcs, assemble_kwargs) + + def value(): + if replace_map is None: + assemble_form = form + else: + assemble_form = ufl.replace(form, replace_map) + if arity == 0: + if len(bcs) > 0: + raise TypeError("Unexpected boundary conditions for arity " + "0 form") + b = assemble(assemble_form, **assemble_kwargs) + elif arity == 1: + b = assemble(assemble_form, **assemble_kwargs) + for bc in bcs: + bc.apply(b) + elif arity == 2: + b = assemble_matrix(assemble_form, bcs=bcs, **assemble_kwargs) + else: + raise ValueError(f"Unexpected form arity {arity:d}") + return b + + return self.add(key, value, + deps=tuple(form_dependencies(form).values())) + + +def linear_solver_key(form, bcs, linear_solver_parameters, + form_compiler_parameters): + return (form_key(form), tuple(bcs), + parameters_key(linear_solver_parameters), + parameters_key(form_compiler_parameters)) + + +class LinearSolverCache(Cache): + """A :class:`tlm_adjoint.caches.Cache` for linear solver data. + """ + + def linear_solver(self, form, *, + bcs=None, form_compiler_parameters=None, + linear_solver_parameters=None, replace_map=None, + assembly_cache=None): + """Construct a linear solver and cache the result, or return a + previously cached result. + + :arg form: An arity two :class:`ufl.Form`, defining the matrix. + :arg bcs: Dirichlet boundary conditions. + :arg form_compiler_parameters: Form compiler parameters. + :arg linear_solver_parameters: Linear solver parameters. + :arg replace_map: A :class:`Mapping` defining a map from symbolic + variables to values. + :arg assembly_cache: :class:`AssemblyCache` to use for finite element + assembly. Defaults to `assembly_cache()`. + :returns: A :class:`tuple` `(value_ref, value)`. `value` is a tuple + `(solver, A, b_bc)`, where `solver` is the linear solver, `A` is + the assembled matrix, and `b_bc` is a boundary condition + right-hand-side term which should be added after assembling a + right-hand-side with homogeneous boundary conditions applied. + `b_bc` may be `None` to indicate that this term is zero. + `value_ref` is a :class:`tlm_adjoint.caches.CacheRef` storing a + reference to `value`. + """ + + if bcs is None: + bcs = () + elif isinstance(bcs, backend_DirichletBC): + bcs = (bcs,) + if form_compiler_parameters is None: + form_compiler_parameters = {} + if linear_solver_parameters is None: + linear_solver_parameters = {} + + form = eliminate_zeros(form, force_non_empty_form=True) + key = linear_solver_key(form, bcs, linear_solver_parameters, + form_compiler_parameters) + + if assembly_cache is None: + assembly_cache = globals()["assembly_cache"]() + + def value(): + _, (A, b_bc) = assembly_cache.assemble( + form, bcs=bcs, + form_compiler_parameters=form_compiler_parameters, + linear_solver_parameters=linear_solver_parameters, + replace_map=replace_map) + solver = linear_solver(matrix_copy(A), + linear_solver_parameters) + return solver, A, b_bc + + return self.add(key, value, + deps=tuple(form_dependencies(form).values())) + + +_assembly_cache = AssemblyCache() + + +def assembly_cache(): + """ + :returns: The default :class:`AssemblyCache`. + """ + + return _assembly_cache + + +def set_assembly_cache(assembly_cache): + """Set the default :class:`AssemblyCache`. + + :arg assembly_cache: The new default :class:`AssemblyCache`. + """ + + global _assembly_cache + _assembly_cache = assembly_cache + + +_linear_solver_cache = LinearSolverCache() + + +def linear_solver_cache(): + """ + :returns: The default :class:`LinearSolverCache`. + """ + + return _linear_solver_cache + + +def set_linear_solver_cache(linear_solver_cache): + """Set the default :class:`LinearSolverCache`. + + :arg linear_solver_cache: The new default :class:`LinearSolverCache`. + """ + + global _linear_solver_cache + _linear_solver_cache = linear_solver_cache diff --git a/tlm_adjoint/firedrake/equations.py b/tlm_adjoint/firedrake/equations.py new file mode 100644 index 00000000..58e039e5 --- /dev/null +++ b/tlm_adjoint/firedrake/equations.py @@ -0,0 +1,952 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +"""This module is used by both the FEniCS and Firedrake backends, and +implements finite element calculations. In particular the +:class:`EquationSolver` class implements the solution of finite element +variational problems. +""" + +from .backend import ( + TestFunction, TrialFunction, adjoint, backend_Constant, + backend_DirichletBC, backend_Function, parameters) +from ..interface import ( + check_space_type, is_var, var_assign, var_id, var_is_scalar, var_new, + var_new_conjugate_dual, var_replacement, var_scalar_value, var_space, + var_zero) +from .backend_code_generator_interface import ( + assemble, assemble_linear_solver, copy_parameters_dict, + form_compiler_quadrature_parameters, homogenize, interpolate_expression, + matrix_multiply, process_adjoint_solver_parameters, + process_solver_parameters, rhs_addto, rhs_copy, solve, + update_parameters_dict, verify_assembly) + +from ..caches import CacheRef +from ..equation import Equation, ZeroAssignment +from ..equations import Assignment + +from .caches import assembly_cache, is_cached, linear_solver_cache, split_form +from .functions import ( + ReplacementConstant, bcs_is_cached, bcs_is_homogeneous, bcs_is_static, + derivative, eliminate_zeros, extract_coefficients) + +import numpy as np +import ufl + +__all__ = \ + [ + "Assembly", + "DirichletBCApplication", + "EquationSolver", + "ExprInterpolation", + "Projection", + "expr_new_x", + "linear_equation_new_x" + ] + + +def derivative_dependencies(expr, dep): + dexpr = derivative(expr, dep, enable_automatic_argument=False) + dexpr = ufl.algorithms.expand_derivatives(dexpr) + return extract_coefficients(dexpr) + + +def extract_dependencies(expr, *, + space_type="primal"): + deps = {} + nl_deps = {} + for dep in extract_coefficients(expr): + if is_var(dep): + deps.setdefault(var_id(dep), dep) + for nl_dep in derivative_dependencies(expr, dep): + if is_var(nl_dep): + nl_deps.setdefault(var_id(dep), dep) + nl_deps.setdefault(var_id(nl_dep), nl_dep) + + deps = {dep_id: deps[dep_id] + for dep_id in sorted(deps.keys())} + nl_deps = {nl_dep_id: nl_deps[nl_dep_id] + for nl_dep_id in sorted(nl_deps.keys())} + + assert len(set(nl_deps.keys()).difference(set(deps.keys()))) == 0 + for dep in deps.values(): + check_space_type(dep, space_type) + + return deps, nl_deps + + +def apply_rhs_bcs(b, hbcs, *, b_bc=None): + for bc in hbcs: + bc.apply(b) + if b_bc is not None: + rhs_addto(b, b_bc) + + +class ExprEquation(Equation): + def _replace_map(self, deps): + eq_deps = self.dependencies() + assert len(eq_deps) == len(deps) + return {eq_dep: dep + for eq_dep, dep in zip(eq_deps, deps) + if isinstance(eq_dep, (ufl.classes.ConstantValue, ufl.classes.Coefficient))} # noqa: E501 + + def _replace(self, expr, deps): + if deps is None: + return expr + else: + replace_map = self._replace_map(deps) + return ufl.replace(expr, replace_map) + + def _nonlinear_replace_map(self, nl_deps): + eq_nl_deps = self.nonlinear_dependencies() + assert len(eq_nl_deps) == len(nl_deps) + return {eq_nl_dep: nl_dep + for eq_nl_dep, nl_dep in zip(eq_nl_deps, nl_deps) + if isinstance(eq_nl_dep, (ufl.classes.ConstantValue, ufl.classes.Coefficient))} # noqa: E501 + + def _nonlinear_replace(self, expr, nl_deps): + replace_map = self._nonlinear_replace_map(nl_deps) + return ufl.replace(expr, replace_map) + + +class Assembly(ExprEquation): + r"""Represents assignment to the result of finite element assembly: + + .. code-block:: python + + x = assemble(rhs) + + The forward residual :math:`\mathcal{F}` is defined so that :math:`\partial + \mathcal{F} / \partial x` is the identity. + + :arg x: A variable defining the forward solution. + :arg rhs: A :class:`ufl.Form` to assemble. Should have arity 0 or 1, and + should not depend on the forward solution. + :arg form_compiler_parameters: Form compiler parameters. + :arg match_quadrature: Whether to set quadrature parameters consistently in + the forward, adjoint, and tangent-linears. Defaults to + `parameters['tlm_adjoint']['Assembly']['match_quadrature']`. + """ + + def __init__(self, x, rhs, *, + form_compiler_parameters=None, match_quadrature=None): + if form_compiler_parameters is None: + form_compiler_parameters = {} + if match_quadrature is None: + match_quadrature = parameters["tlm_adjoint"]["Assembly"]["match_quadrature"] # noqa: E501 + + rhs = ufl.classes.Form(rhs.integrals()) + + arity = len(rhs.arguments()) + if arity == 0: + check_space_type(x, "primal") + if not var_is_scalar(x): + raise ValueError("Arity 0 forms can only be assigned to " + "scalars") + elif arity == 1: + check_space_type(x, "conjugate_dual") + else: + raise ValueError("Must be an arity 0 or arity 1 form") + + deps, nl_deps = extract_dependencies(rhs) + if var_id(x) in deps: + raise ValueError("Invalid dependency") + deps, nl_deps = list(deps.values()), tuple(nl_deps.values()) + deps.insert(0, x) + + form_compiler_parameters_ = \ + copy_parameters_dict(parameters["form_compiler"]) + update_parameters_dict(form_compiler_parameters_, + form_compiler_parameters) + form_compiler_parameters = form_compiler_parameters_ + del form_compiler_parameters_ + if match_quadrature: + update_parameters_dict( + form_compiler_parameters, + form_compiler_quadrature_parameters(rhs, form_compiler_parameters)) # noqa: E501 + + super().__init__(x, deps, nl_deps=nl_deps, ic=False, adj_ic=False) + self._rhs = rhs + self._form_compiler_parameters = form_compiler_parameters + self._arity = arity + + def drop_references(self): + replace_map = {dep: var_replacement(dep) + for dep in self.dependencies() + if isinstance(dep, (ufl.classes.ConstantValue, ufl.classes.Coefficient))} # noqa: E501 + + super().drop_references() + self._rhs = ufl.replace(self._rhs, replace_map) + + def forward_solve(self, x, deps=None): + rhs = self._replace(self._rhs, deps) + + if self._arity == 0: + var_assign( + x, + assemble(rhs, form_compiler_parameters=self._form_compiler_parameters)) # noqa: E501 + elif self._arity == 1: + assemble( + rhs, form_compiler_parameters=self._form_compiler_parameters, + tensor=x) + else: + raise ValueError("Must be an arity 0 or arity 1 form") + + def adjoint_derivative_action(self, nl_deps, dep_index, adj_x): + # Derived from EquationSolver.derivative_action (see dolfin-adjoint + # reference below). Code first added 2017-12-07. + # Re-written 2018-01-28 + # Updated to adjoint only form 2018-01-29 + + eq_deps = self.dependencies() + if dep_index < 0 or dep_index >= len(eq_deps): + raise IndexError("dep_index out of bounds") + elif dep_index == 0: + return adj_x + + dep = eq_deps[dep_index] + dF = derivative(self._rhs, dep) + dF = ufl.algorithms.expand_derivatives(dF) + dF = eliminate_zeros(dF) + if dF.empty(): + return None + + dF = self._nonlinear_replace(dF, nl_deps) + if self._arity == 0: + dF = ufl.classes.Form( + [integral.reconstruct(integrand=ufl.conj(integral.integrand())) + for integral in dF.integrals()]) # dF = adjoint(dF) + dF = assemble( + dF, form_compiler_parameters=self._form_compiler_parameters) + return (-var_scalar_value(adj_x), dF) + elif self._arity == 1: + dF = ufl.action(adjoint(dF), coefficient=adj_x) + dF = assemble( + dF, form_compiler_parameters=self._form_compiler_parameters) + return (-1.0, dF) + else: + raise ValueError("Must be an arity 0 or arity 1 form") + + def adjoint_jacobian_solve(self, adj_x, nl_deps, b): + return b + + def tangent_linear(self, M, dM, tlm_map): + x = self.x() + + tlm_rhs = ufl.classes.Form([]) + for dep in self.dependencies(): + if dep != x: + tau_dep = tlm_map[dep] + if tau_dep is not None: + tlm_rhs = tlm_rhs + derivative(self._rhs, dep, argument=tau_dep) # noqa: E501 + + tlm_rhs = ufl.algorithms.expand_derivatives(tlm_rhs) + if tlm_rhs.empty(): + return ZeroAssignment(tlm_map[x]) + else: + return Assembly( + tlm_map[x], tlm_rhs, + form_compiler_parameters=self._form_compiler_parameters) + + +def homogenized_bc(bc): + if bcs_is_homogeneous(bc): + return bc + else: + hbc = homogenize(bc) + hbc._tlm_adjoint__static = bcs_is_static(bc) + hbc._tlm_adjoint__cache = bcs_is_cached(bc) + hbc._tlm_adjoint__homogeneous = True + return hbc + + +class EquationSolver(ExprEquation): + """Represents the solution of a finite element variational problem. + + Caching is based on the approach described in + + - J. R. Maddison and P. E. Farrell, 'Rapid development and adjoining of + transient finite element models', Computer Methods in Applied + Mechanics and Engineering, 276, 95--121, 2014, doi: + 10.1016/j.cma.2014.03.010 + + The arguments `eq`, `x`, `bcs`, `J`, `form_compiler_parameters`, and + `solver_parameters` are based on the interface for the FEniCS :func:`solve` + function (see e.g. FEniCS 2017.1.0). + + :arg eq: A :class:`ufl.equation.Equation` defining the finite element + variational problem. + :arg x: A backend `Function` defining the forward solution. + :arg bcs: Dirichlet boundary conditions. + :arg J: A :class:`ufl.Form` defining a Jacobian matrix approximation to use + in a non-linear forward solve. + :arg form_compiler_parameters: Form compiler parameters. + :arg solver_parameters: Linear or non-linear solver parameters. + :arg adjoint_solver_parameters: Linear solver parameters to use in an + adjoint solve. + :arg tlm_solver_parameters: Linear solver parameters to use when solving + tangent-linear problems. + :arg cache_jacobian: Whether to cache the forward Jacobian matrix and + linear solver data. Defaults to + `parameters['tlm_adjoint']['EquationSolver]['cache_jacobian']`. If + `None` then caching is autodetected. + :arg cache_adjoint_jacobian: Whether to cache the adjoint Jacobian matrix + and linear solver data. Defaults to `cache_jacobian`. + :arg cache_tlm_jacobian: Whether to cache the Jacobian matrix and linear + solver data when solving tangent-linear problems. Defaults to + `cache_jacobian`. + :arg cache_rhs_assembly: Whether to enable right-hand-side caching. If + enabled then right-hand-side terms are divided into terms which are + cached, terms which are converted into matrix multiplication by a + cached matrix, and terms which are not cached. Defaults to + `parameters['tlm_adjoint']['EquationSolver']['cache_rhs_assembly']`. + :arg match_quadrature: Whether to set quadrature parameters consistently in + the forward, adjoint, and tangent-linears. Defaults to + `parameters['tlm_adjoint']['EquationSolver']['match_quadrature']`. + :arg defer_adjoint_assembly: Whether to use 'deferred' adjoint assembly. If + adjoint assembly is deferred then initially only symbolic expressions + for adjoint right-hand-side terms are constructed. Finite element + assembly can occur later (with default form compiler parameters), when + further adjoint right-hand-side terms are available. Defaults to + `parameters['tlm_adjoint']['EquationSolver']['defer_adjoint_assembly']`. + """ + + def __init__(self, eq, x, bcs=None, *, + J=None, form_compiler_parameters=None, solver_parameters=None, + adjoint_solver_parameters=None, tlm_solver_parameters=None, + cache_jacobian=None, cache_adjoint_jacobian=None, + cache_tlm_jacobian=None, cache_rhs_assembly=None, + match_quadrature=None, defer_adjoint_assembly=None): + if bcs is None: + bcs = [] + if form_compiler_parameters is None: + form_compiler_parameters = {} + if solver_parameters is None: + solver_parameters = {} + + if isinstance(bcs, backend_DirichletBC): + bcs = (bcs,) + else: + bcs = tuple(bcs) + if cache_jacobian is None: + if not parameters["tlm_adjoint"]["EquationSolver"]["enable_jacobian_caching"]: # noqa: E501 + cache_jacobian = False + if cache_rhs_assembly is None: + cache_rhs_assembly = parameters["tlm_adjoint"]["EquationSolver"]["cache_rhs_assembly"] # noqa: E501 + if match_quadrature is None: + match_quadrature = parameters["tlm_adjoint"]["EquationSolver"]["match_quadrature"] # noqa: E501 + if defer_adjoint_assembly is None: + defer_adjoint_assembly = parameters["tlm_adjoint"]["EquationSolver"]["defer_adjoint_assembly"] # noqa: E501 + if match_quadrature and defer_adjoint_assembly: + raise ValueError("Cannot both match quadrature and defer adjoint " + "assembly") + + check_space_type(x, "primal") + + lhs, rhs = eq.lhs, eq.rhs + del eq + lhs = ufl.classes.Form(lhs.integrals()) + linear = isinstance(rhs, ufl.classes.Form) + if linear: + rhs = ufl.classes.Form(rhs.integrals()) + if J is not None: + J = ufl.classes.Form(J.integrals()) + + if linear: + if len(lhs.arguments()) != 2: + raise ValueError("Invalid left-hand-side arguments") + if rhs.arguments() != (lhs.arguments()[0],): + raise ValueError("Invalid right-hand-side arguments") + if x in extract_coefficients(lhs) \ + or x in extract_coefficients(rhs): + raise ValueError("Invalid dependency") + + F = ufl.action(lhs, coefficient=x) - rhs + nl_solve_J = None + J = lhs + else: + if len(lhs.arguments()) != 1: + raise ValueError("Invalid left-hand-side arguments") + if not isinstance(rhs, int) or rhs != 0: + raise ValueError("Invalid right-hand-side") + + F = lhs + nl_solve_J = J + J = derivative(F, x) + J = ufl.algorithms.expand_derivatives(J) + + deps, nl_deps = extract_dependencies(F) + if nl_solve_J is not None: + for dep in extract_coefficients(nl_solve_J): + if is_var(dep): + deps.setdefault(var_id(dep), dep) + + deps = list(deps.values()) + if x in deps: + deps.remove(x) + deps.insert(0, x) + nl_deps = tuple(nl_deps.values()) + + hbcs = tuple(map(homogenized_bc, bcs)) + + if cache_jacobian is None: + cache_jacobian = is_cached(J) and bcs_is_cached(bcs) + if cache_adjoint_jacobian is None: + cache_adjoint_jacobian = cache_jacobian + if cache_tlm_jacobian is None: + cache_tlm_jacobian = cache_jacobian + + (solver_parameters, linear_solver_parameters, + ic, J_ic) = process_solver_parameters(solver_parameters, linear) + + if adjoint_solver_parameters is None: + adjoint_solver_parameters = process_adjoint_solver_parameters(linear_solver_parameters) # noqa: E501 + adj_ic = J_ic + else: + (_, adjoint_solver_parameters, + adj_ic, _) = process_solver_parameters(adjoint_solver_parameters, linear=True) # noqa: E501 + + if tlm_solver_parameters is not None: + (_, tlm_solver_parameters, + _, _) = process_solver_parameters(tlm_solver_parameters, linear=True) # noqa: E501 + + form_compiler_parameters_ = copy_parameters_dict(parameters["form_compiler"]) # noqa: E501 + update_parameters_dict(form_compiler_parameters_, + form_compiler_parameters) + form_compiler_parameters = form_compiler_parameters_ + del form_compiler_parameters_ + if match_quadrature: + update_parameters_dict( + form_compiler_parameters, + form_compiler_quadrature_parameters(F, form_compiler_parameters)) # noqa: E501 + + super().__init__(x, deps, nl_deps=nl_deps, + ic=ic, adj_ic=adj_ic, adj_type="primal") + self._F = F + self._lhs = lhs + self._rhs = rhs + self._bcs = bcs + self._hbcs = hbcs + self._J = J + self._nl_solve_J = nl_solve_J + self._form_compiler_parameters = form_compiler_parameters + self._solver_parameters = solver_parameters + self._linear_solver_parameters = linear_solver_parameters + self._adjoint_solver_parameters = adjoint_solver_parameters + self._tlm_solver_parameters = tlm_solver_parameters + self._linear = linear + + self._cache_jacobian = cache_jacobian + self._cache_adjoint_jacobian = cache_adjoint_jacobian + self._cache_tlm_jacobian = cache_tlm_jacobian + self._cache_rhs_assembly = cache_rhs_assembly + self._defer_adjoint_assembly = defer_adjoint_assembly + + self._forward_J_solver = CacheRef() + self._forward_b_pa = None + + self._adjoint_dF_cache = {} + self._adjoint_action_cache = {} + + self._adjoint_J_solver = CacheRef() + + def drop_references(self): + replace_map = {dep: var_replacement(dep) + for dep in self.dependencies()} + + super().drop_references() + + self._F = ufl.replace(self._F, replace_map) + self._lhs = ufl.replace(self._lhs, replace_map) + if isinstance(self._rhs, ufl.classes.Form): + self._rhs = ufl.replace(self._rhs, replace_map) + self._J = ufl.replace(self._J, replace_map) + if self._nl_solve_J is not None: + self._nl_solve_J = ufl.replace(self._nl_solve_J, replace_map) + + if self._forward_b_pa is not None: + cached_form, mat_forms, non_cached_form = self._forward_b_pa + + if cached_form is not None: + cached_form[0] = ufl.replace(cached_form[0], replace_map) + for dep_index, (mat_form, _) in mat_forms.items(): + mat_forms[dep_index][0] = ufl.replace(mat_form, replace_map) + if non_cached_form is not None: + non_cached_form = ufl.replace(non_cached_form, replace_map) + + self._forward_b_pa = (cached_form, mat_forms, non_cached_form) + + for dep_index, dF in self._adjoint_dF_cache.items(): + if dF is not None: + self._adjoint_dF_cache[dep_index] = ufl.replace(dF, replace_map) # noqa: E501 + + def _cached_rhs(self, deps, *, b_bc=None): + eq_deps = self.dependencies() + + if self._forward_b_pa is None: + rhs = eliminate_zeros(self._rhs, force_non_empty_form=True) + cached_form, mat_forms_, non_cached_form = split_form(rhs) + + dep_indices = {var_id(dep): dep_index + for dep_index, dep in enumerate(eq_deps)} + mat_forms = {dep_indices[dep_id]: [mat_forms_[dep_id], CacheRef()] + for dep_id in mat_forms_} + del mat_forms_, dep_indices + + if non_cached_form.empty(): + non_cached_form = None + + if cached_form.empty(): + cached_form = None + else: + cached_form = [cached_form, CacheRef()] + + self._forward_b_pa = (cached_form, mat_forms, non_cached_form) + else: + cached_form, mat_forms, non_cached_form = self._forward_b_pa + + b = None + + if non_cached_form is not None: + b = assemble( + self._replace(non_cached_form, deps), + form_compiler_parameters=self._form_compiler_parameters) + + for dep_index, (mat_form, mat_cache) in mat_forms.items(): + mat_bc = mat_cache() + if mat_bc is None: + mat_forms[dep_index][1], mat_bc = assembly_cache().assemble( + mat_form, + form_compiler_parameters=self._form_compiler_parameters, + linear_solver_parameters=self._linear_solver_parameters, + replace_map=None if deps is None else self._replace_map(deps)) # noqa: E501 + mat, _ = mat_bc + dep = (eq_deps if deps is None else deps)[dep_index] + if b is None: + b = matrix_multiply(mat, dep) + else: + matrix_multiply(mat, dep, tensor=b, addto=True) + + if cached_form is not None: + cached_b = cached_form[1]() + if cached_b is None: + cached_form[1], cached_b = assembly_cache().assemble( + cached_form[0], + form_compiler_parameters=self._form_compiler_parameters, + replace_map=None if deps is None else self._replace_map(deps)) # noqa: E501 + if b is None: + b = rhs_copy(cached_b) + else: + rhs_addto(b, cached_b) + + if b is None: + b = var_new_conjugate_dual(self.x()) + + apply_rhs_bcs(b, self._hbcs, b_bc=b_bc) + return b + + def forward_solve(self, x, deps=None): + if self._linear: + if self._cache_jacobian: + # Cases 1 and 2: Linear, Jacobian cached, with or without RHS + # assembly caching + + J_solver_mat_bc = self._forward_J_solver() + if J_solver_mat_bc is None: + # Assemble and cache the Jacobian, construct and cache the + # linear solver + self._forward_J_solver, J_solver_mat_bc = \ + linear_solver_cache().linear_solver( + self._J, bcs=self._bcs, + form_compiler_parameters=self._form_compiler_parameters, # noqa: E501 + linear_solver_parameters=self._linear_solver_parameters, # noqa: E501 + replace_map=None if deps is None else self._replace_map(deps)) # noqa: E501 + J_solver, J_mat, b_bc = J_solver_mat_bc + + if self._cache_rhs_assembly: + # Assemble the RHS with RHS assembly caching + b = self._cached_rhs(deps, b_bc=b_bc) + else: + # Assemble the RHS without RHS assembly caching + b = assemble( + self._replace(self._rhs, deps), + form_compiler_parameters=self._form_compiler_parameters) # noqa: E501 + + # Add bc RHS terms + apply_rhs_bcs(b, self._hbcs, b_bc=b_bc) + else: + if self._cache_rhs_assembly: + # Case 3: Linear, Jacobian not cached, with RHS assembly + # caching + + # Construct the linear solver, assemble the Jacobian + J_solver, J_mat, b_bc = assemble_linear_solver( + self._replace(self._J, deps), bcs=self._bcs, + form_compiler_parameters=self._form_compiler_parameters, # noqa: E501 + linear_solver_parameters=self._linear_solver_parameters) # noqa: E501 + + # Assemble the RHS with RHS assembly caching + b = self._cached_rhs(deps, b_bc=b_bc) + else: + # Case 4: Linear, Jacobian not cached, without RHS assembly + # caching + + # Construct the linear solver, assemble the Jacobian and + # RHS + J_solver, J_mat, b = assemble_linear_solver( + self._replace(self._J, deps), + b_form=self._replace(self._rhs, deps), bcs=self._bcs, + form_compiler_parameters=self._form_compiler_parameters, # noqa: E501 + linear_solver_parameters=self._linear_solver_parameters) # noqa: E501 + + J_tolerance = parameters["tlm_adjoint"]["assembly_verification"]["jacobian_tolerance"] # noqa: E501 + b_tolerance = parameters["tlm_adjoint"]["assembly_verification"]["rhs_tolerance"] # noqa: E501 + if not np.isposinf(J_tolerance) or not np.isposinf(b_tolerance): + verify_assembly( + self._replace(self._J, deps), + self._replace(self._rhs, deps), + J_mat, b, self._bcs, self._form_compiler_parameters, + self._linear_solver_parameters, J_tolerance, b_tolerance) + + J_solver.solve(x, b) + else: + # Case 5: Non-linear + lhs = self._lhs + assert isinstance(self._rhs, int) and self._rhs == 0 + if self._nl_solve_J is None: + J = self._J + else: + J = self._nl_solve_J + solve(self._replace(lhs, deps) == 0, x, self._bcs, + J=self._replace(J, deps), + form_compiler_parameters=self._form_compiler_parameters, + solver_parameters=self._solver_parameters) + + def subtract_adjoint_derivative_actions(self, adj_x, nl_deps, dep_Bs): + for dep_index, dep_B in dep_Bs.items(): + if dep_index not in self._adjoint_dF_cache: + dep = self.dependencies()[dep_index] + dF = derivative(self._F, dep) + dF = ufl.algorithms.expand_derivatives(dF) + dF = eliminate_zeros(dF) + if dF.empty(): + dF = None + else: + dF = adjoint(dF) + self._adjoint_dF_cache[dep_index] = dF + dF = self._adjoint_dF_cache[dep_index] + + if dF is not None: + if dep_index not in self._adjoint_action_cache: + if self._cache_rhs_assembly \ + and isinstance(adj_x, backend_Function) \ + and is_cached(dF): + self._adjoint_action_cache[dep_index] = CacheRef() + else: + self._adjoint_action_cache[dep_index] = None + + if self._adjoint_action_cache[dep_index] is not None: + # Cached matrix action + mat_bc = self._adjoint_action_cache[dep_index]() + if mat_bc is None: + self._adjoint_action_cache[dep_index], mat_bc = \ + assembly_cache().assemble( + dF, + form_compiler_parameters=self._form_compiler_parameters, # noqa: E501 + replace_map=self._nonlinear_replace_map(nl_deps)) # noqa: E501 + mat, _ = mat_bc + dep_B.sub(matrix_multiply(mat, adj_x)) + else: + # Cached form + dF = ufl.action(self._nonlinear_replace(dF, nl_deps), + coefficient=adj_x) + if not self._defer_adjoint_assembly: + # Immediate assembly + dF = assemble(dF, form_compiler_parameters=self._form_compiler_parameters) # noqa: E501 + dep_B.sub(dF) + + # def adjoint_derivative_action(self, nl_deps, dep_index, adj_x): + # # Similar to 'RHS.derivative_action' and + # # 'RHS.second_derivative_action' in dolfin-adjoint file + # # dolfin_adjoint/adjrhs.py (see e.g. dolfin-adjoint version 2017.1.0) + # # Code first added to JRM personal repository 2016-05-22 + # # Code first added to dolfin_adjoint_custom repository 2016-06-02 + # # Re-written 2018-01-28 + + def adjoint_jacobian_solve(self, adj_x, nl_deps, b): + if adj_x is None: + adj_x = self.new_adj_x() + + if self._cache_adjoint_jacobian: + J_solver_mat_bc = self._adjoint_J_solver() + if J_solver_mat_bc is None: + self._adjoint_J_solver, J_solver_mat_bc = \ + linear_solver_cache().linear_solver( + adjoint(self._J), bcs=self._hbcs, + form_compiler_parameters=self._form_compiler_parameters, # noqa: E501 + linear_solver_parameters=self._adjoint_solver_parameters, # noqa: E501 + replace_map=self._nonlinear_replace_map(nl_deps)) + else: + J_solver_mat_bc = assemble_linear_solver( + self._nonlinear_replace(adjoint(self._J), nl_deps), + bcs=self._hbcs, + form_compiler_parameters=self._form_compiler_parameters, + linear_solver_parameters=self._adjoint_solver_parameters) + J_solver, _, _ = J_solver_mat_bc + + apply_rhs_bcs(b, self._hbcs) + J_solver.solve(adj_x, b) + return adj_x + + def tangent_linear(self, M, dM, tlm_map): + x = self.x() + + tlm_rhs = ufl.classes.Form([]) + for dep in self.dependencies(): + if dep != x: + tau_dep = tlm_map[dep] + if tau_dep is not None: + tlm_rhs = (tlm_rhs + - derivative(self._F, dep, argument=tau_dep)) + + tlm_rhs = ufl.algorithms.expand_derivatives(tlm_rhs) + if tlm_rhs.empty(): + return ZeroAssignment(tlm_map[x]) + else: + if self._tlm_solver_parameters is None: + tlm_solver_parameters = self._linear_solver_parameters + else: + tlm_solver_parameters = self._tlm_solver_parameters + return EquationSolver( + self._J == tlm_rhs, tlm_map[x], self._hbcs, + form_compiler_parameters=self._form_compiler_parameters, + solver_parameters=tlm_solver_parameters, + adjoint_solver_parameters=self._adjoint_solver_parameters, + tlm_solver_parameters=tlm_solver_parameters, + cache_jacobian=self._cache_tlm_jacobian, + cache_adjoint_jacobian=self._cache_adjoint_jacobian, + cache_tlm_jacobian=self._cache_tlm_jacobian, + cache_rhs_assembly=self._cache_rhs_assembly, + defer_adjoint_assembly=self._defer_adjoint_assembly) + + +def expr_new_x(expr, x, *, + annotate=None, tlm=None): + """If an expression depends on `x`, then record the assignment `x_old = + x`, and replace `x` with `x_old` in the expression. + + :arg expr: A :class:`ufl.core.expr.Expr`. + :arg x: Defines `x`. + :arg annotate: Whether the :class:`tlm_adjoint.tlm_adjoint.EquationManager` + should record the solution of equations. + :arg tlm: Whether tangent-linear equations should be solved. + :returns: A :class:`ufl.core.expr.Expr` with `x` replaced with `x_old`, or + `expr` if the expression does not depend on `x`. + """ + + if x in extract_coefficients(expr): + x_old = var_new(x) + Assignment(x_old, x).solve(annotate=annotate, tlm=tlm) + return ufl.replace(expr, {x: x_old}) + else: + return expr + + +def linear_equation_new_x(eq, x, *, + annotate=None, tlm=None): + """If a symbolic expression for a linear finite element variational + problem depends on the symbolic variable representing the problem solution, + then record the assignment `x_old = x`, and replace `x` with `x_old` in the + symbolic expression. + + Required for the case where a 'new' value is computed by solving a linear + finite element variational problem depending on the 'old' value. + + :arg eq: A :class:`ufl.equation.Equation` defining the finite element + variational problem. + :arg x: A backend `Function` defining the solution to the finite element + variational problem. + :arg annotate: Whether the :class:`tlm_adjoint.tlm_adjoint.EquationManager` + should record the solution of equations. + :arg tlm: Whether tangent-linear equations should be solved. + :returns: A :class:`ufl.equation.Equation` with `x` replaced with `x_old`, + or `eq` if the symbolic expression does not depend on `x`. + """ + + lhs, rhs = eq.lhs, eq.rhs + lhs_x_dep = x in extract_coefficients(lhs) + rhs_x_dep = x in extract_coefficients(rhs) + if lhs_x_dep or rhs_x_dep: + x_old = var_new(x) + Assignment(x_old, x).solve(annotate=annotate, tlm=tlm) + if lhs_x_dep: + lhs = ufl.replace(lhs, {x: x_old}) + if rhs_x_dep: + rhs = ufl.replace(rhs, {x: x_old}) + return lhs == rhs + else: + return eq + + +class Projection(EquationSolver): + """Represents the solution of a finite element variational problem + performing a projection onto the space for `x`. + + :arg x: A backend `Function` defining the forward solution. + :arg rhs: A :class:`ufl.core.expr.Expr` defining the expression to project + onto the space for `x`, or a :class:`ufl.Form` defining the + right-hand-side of the finite element variational problem. Should not + depend on `x`. + + Remaining arguments are passed to the :class:`EquationSolver` constructor. + """ + + def __init__(self, x, rhs, *args, **kwargs): + space = var_space(x) + test, trial = TestFunction(space), TrialFunction(space) + if not isinstance(rhs, ufl.classes.Form): + rhs = ufl.inner(rhs, test) * ufl.dx + super().__init__(ufl.inner(trial, test) * ufl.dx == rhs, x, + *args, **kwargs) + + +class DirichletBCApplication(Equation): + r"""Represents the application of a Dirichlet boundary condition to a zero + valued backend `Function`. Specifically, with the Firedrake backend this + represents: + + .. code-block:: python + + x.zero() + DirichletBC(x.function_space(), y, *args, **kwargs).apply(x) + + The forward residual :math:`\mathcal{F}` is defined so that :math:`\partial + \mathcal{F} / \partial x` is the identity. + + :arg x: A backend `Function`, updated by the above operations. + :arg y: A backend `Function`, defines the Dirichet boundary condition. + + Remaining arguments are passed to `DirichletBC`. + """ + + def __init__(self, x, y, *args, **kwargs): + check_space_type(x, "primal") + check_space_type(y, "primal") + + super().__init__(x, [x, y], nl_deps=[], ic=False, adj_ic=False) + self._bc_args = args + self._bc_kwargs = kwargs + + def forward_solve(self, x, deps=None): + _, y = self.dependencies() if deps is None else deps + var_zero(x) + backend_DirichletBC( + var_space(x), y, + *self._bc_args, **self._bc_kwargs).apply(x) + + def adjoint_derivative_action(self, nl_deps, dep_index, adj_x): + if dep_index == 0: + return adj_x + elif dep_index == 1: + _, y = self.dependencies() + F = var_new_conjugate_dual(y) + backend_DirichletBC( + var_space(y), adj_x, + *self._bc_args, **self._bc_kwargs).apply(F) + return (-1.0, F) + else: + raise IndexError("dep_index out of bounds") + + def adjoint_jacobian_solve(self, adj_x, nl_deps, b): + return b + + def tangent_linear(self, M, dM, tlm_map): + x, y = self.dependencies() + + tau_y = tlm_map[y] + if tau_y is None: + return ZeroAssignment(tlm_map[x]) + else: + return DirichletBCApplication( + tlm_map[x], tau_y, + *self._bc_args, **self._bc_kwargs) + + +class ExprInterpolation(ExprEquation): + r"""Represents interpolation of `rhs` onto the space for `x`. + + The forward residual :math:`\mathcal{F}` is defined so that :math:`\partial + \mathcal{F} / \partial x` is the identity. + + :arg x: The forward solution. + :arg rhs: A :class:`ufl.core.expr.Expr` defining the expression to + interpolate onto the space for `x`. Should not depend on `x`. + """ + + def __init__(self, x, rhs): + deps, nl_deps = extract_dependencies(rhs) + if var_id(x) in deps: + raise ValueError("Invalid dependency") + deps, nl_deps = list(deps.values()), tuple(nl_deps.values()) + deps.insert(0, x) + + super().__init__(x, deps, nl_deps=nl_deps, ic=False, adj_ic=False) + self._rhs = rhs + + def drop_references(self): + replace_map = {dep: var_replacement(dep) + for dep in self.dependencies()} + + super().drop_references() + self._rhs = ufl.replace(self._rhs, replace_map) + + def forward_solve(self, x, deps=None): + interpolate_expression(x, self._replace(self._rhs, deps)) + + def adjoint_derivative_action(self, nl_deps, dep_index, adj_x): + eq_deps = self.dependencies() + if dep_index < 0 or dep_index >= len(eq_deps): + raise IndexError("dep_index out of bounds") + elif dep_index == 0: + return adj_x + + dep = eq_deps[dep_index] + + F = var_new_conjugate_dual(dep) + + if isinstance(dep, (backend_Constant, ReplacementConstant)): + if len(dep.ufl_shape) > 0: + raise NotImplementedError("Case not implemented") + dF = derivative(self._rhs, dep, argument=ufl.classes.IntValue(1)) + else: + dF = derivative(self._rhs, dep) + dF = ufl.algorithms.expand_derivatives(dF) + dF = eliminate_zeros(dF) + dF = self._nonlinear_replace(dF, nl_deps) + + interpolate_expression(F, dF, adj_x=adj_x) + return (-1.0, F) + + def adjoint_jacobian_solve(self, adj_x, nl_deps, b): + return b + + def tangent_linear(self, M, dM, tlm_map): + x = self.x() + + tlm_rhs = ufl.classes.Zero(shape=x.ufl_shape) + for dep in self.dependencies(): + if dep != x: + tau_dep = tlm_map[dep] + if tau_dep is not None: + # Cannot use += as Firedrake might add to the *values* for + # tlm_rhs + tlm_rhs = (tlm_rhs + + derivative(self._rhs, dep, argument=tau_dep)) + + if isinstance(tlm_rhs, ufl.classes.Zero): + return ZeroAssignment(tlm_map[x]) + tlm_rhs = ufl.algorithms.expand_derivatives(tlm_rhs) + if isinstance(tlm_rhs, ufl.classes.Zero): + return ZeroAssignment(tlm_map[x]) + else: + return ExprInterpolation(tlm_map[x], tlm_rhs) diff --git a/tlm_adjoint/firedrake/functions.py b/tlm_adjoint/firedrake/functions.py new file mode 100644 index 00000000..43cbb66e --- /dev/null +++ b/tlm_adjoint/firedrake/functions.py @@ -0,0 +1,761 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +"""This module is used by both the FEniCS and Firedrake backends, and includes +functionality for handling :class:`ufl.Coefficient` objects and boundary +conditions. +""" + +from .backend import ( + TestFunction, TrialFunction, backend_Constant, backend_DirichletBC, + backend_ScalarType) +from ..interface import ( + DEFAULT_COMM, SpaceInterface, add_interface, comm_parent, is_var, + space_comm, var_caches, var_comm, var_dtype, var_derivative_space, var_id, + var_is_cached, var_is_replacement, var_is_static, var_linf_norm, var_name, + var_replacement, var_scalar_value, var_space, var_space_type) +from ..interface import VariableInterface as _VariableInterface + +from ..caches import Caches +from ..manager import manager_disabled +from ..overloaded_float import SymbolicFloat + +import numpy as np +import ufl +import weakref + +__all__ = \ + [ + "Constant", + "extract_coefficients", + + "Zero", + "ZeroConstant", + "eliminate_zeros", + + "Replacement", + "ReplacementConstant", + "ReplacementFunction", + + "DirichletBC", + "HomogeneousDirichletBC" + ] + + +class ConstantSpaceInterface(SpaceInterface): + def _comm(self): + return self._tlm_adjoint__space_interface_attrs["comm"] + + def _dtype(self): + return self._tlm_adjoint__space_interface_attrs["dtype"] + + def _id(self): + return self._tlm_adjoint__space_interface_attrs["id"] + + def _new(self, *, name=None, space_type="primal", static=False, + cache=None): + domain = self._tlm_adjoint__space_interface_attrs["domain"] + return Constant(name=name, domain=domain, space=self, + space_type=space_type, static=static, cache=cache) + + +class ConstantInterface(_VariableInterface): + def _space(self): + return self._tlm_adjoint__var_interface_attrs["space"] + + def _derivative_space(self): + return self._tlm_adjoint__var_interface_attrs["derivative_space"](self) + + def _space_type(self): + return self._tlm_adjoint__var_interface_attrs["space_type"] + + def _dtype(self): + return self._tlm_adjoint__var_interface_attrs["dtype"] + + def _id(self): + return self._tlm_adjoint__var_interface_attrs["id"] + + def _name(self): + return self._tlm_adjoint__var_interface_attrs["name"](self) + + def _state(self): + return self._tlm_adjoint__var_interface_attrs["state"][0] + + def _update_state(self): + self._tlm_adjoint__var_interface_attrs["state"][0] += 1 + + def _is_static(self): + return self._tlm_adjoint__var_interface_attrs["static"] + + def _is_cached(self): + return self._tlm_adjoint__var_interface_attrs["cache"] + + def _caches(self): + if "caches" not in self._tlm_adjoint__var_interface_attrs: + self._tlm_adjoint__var_interface_attrs["caches"] \ + = Caches(self) + return self._tlm_adjoint__var_interface_attrs["caches"] + + @manager_disabled() + def _zero(self): + if len(self.ufl_shape) == 0: + value = 0.0 + else: + value = np.zeros(self.ufl_shape, dtype=var_dtype(self)) + value = backend_Constant(value) + self.assign(value) + + @manager_disabled() + def _assign(self, y): + if isinstance(y, SymbolicFloat): + y = y.value + if isinstance(y, (int, np.integer, + float, np.floating, + complex, np.complexfloating)): + if len(self.ufl_shape) == 0: + value = y + else: + value = np.full(self.ufl_shape, y, dtype=var_dtype(self)) + value = backend_Constant(value) + elif isinstance(y, backend_Constant): + value = y + else: + raise TypeError(f"Unexpected type: {type(y)}") + self.assign(value) + + @manager_disabled() + def _axpy(self, alpha, x, /): + if isinstance(x, SymbolicFloat): + x = x.value + if isinstance(x, (int, np.integer, + float, np.floating, + complex, np.complexfloating)): + if len(self.ufl_shape) == 0: + value = (var_scalar_value(self) + alpha * x) + else: + value = self.values() + alpha * x + value.shape = self.ufl_shape + value = backend_Constant(value) + elif isinstance(x, backend_Constant): + if len(self.ufl_shape) == 0: + value = (var_scalar_value(self) + + alpha * var_scalar_value(x)) + else: + value = self.values() + alpha * x.values() + value.shape = self.ufl_shape + value = backend_Constant(value) + elif is_var(x): + value = (var_scalar_value(self) + + alpha * var_scalar_value(x)) + else: + raise TypeError(f"Unexpected type: {type(x)}") + self.assign(value) + + def _inner(self, y): + if isinstance(y, backend_Constant): + return y.values().conjugate().dot(self.values()) + else: + raise TypeError(f"Unexpected type: {type(y)}") + + def _sum(self): + return self.values().sum() + + def _linf_norm(self): + return abs(self.values()).max(initial=0.0) + + def _local_size(self): + comm = var_comm(self) + if comm.rank == 0: + if len(self.ufl_shape) == 0: + return 1 + else: + return np.prod(self.ufl_shape) + else: + return 0 + + def _global_size(self): + if len(self.ufl_shape) == 0: + return 1 + else: + return np.prod(self.ufl_shape) + + def _local_indices(self): + comm = var_comm(self) + if comm.rank == 0: + if len(self.ufl_shape) == 0: + return slice(0, 1) + else: + return slice(0, np.prod(self.ufl_shape)) + else: + return slice(0, 0) + + def _get_values(self): + comm = var_comm(self) + if comm.rank == 0: + values = self.values().copy() + else: + values = np.array([], dtype=var_dtype(self)) + return values + + @manager_disabled() + def _set_values(self, values): + comm = var_comm(self) + if comm.rank != 0: + values = None + values = comm.bcast(values, root=0) + if len(self.ufl_shape) == 0: + values.shape = (1,) + self.assign(values[0]) + else: + values.shape = self.ufl_shape + self.assign(backend_Constant(values)) + + def _replacement(self): + if not hasattr(self, "_tlm_adjoint__replacement"): + self._tlm_adjoint__replacement = ReplacementConstant(self) + return self._tlm_adjoint__replacement + + def _is_replacement(self): + return False + + def _is_scalar(self): + return len(self.ufl_shape) == 0 + + def _scalar_value(self): + # assert var_is_scalar(self) + return var_dtype(self)(self) + + def _is_alias(self): + return "alias" in self._tlm_adjoint__var_interface_attrs + + +def constant_value(value=None, shape=None): + if value is None: + if shape is None: + shape = () + elif shape is not None: + value_ = value + if not isinstance(value_, np.ndarray): + value_ = np.array(value_) + if value_.shape != shape: + raise ValueError("Invalid shape") + del value_ + + # Default value + if value is None: + if len(shape) == 0: + value = 0.0 + else: + value = np.zeros(shape, dtype=backend_ScalarType) + + return value + + +class Constant(backend_Constant): + """Extends the backend `Constant` class. + + :arg value: The initial value. `None` indicates a value of zero. + :arg name: A :class:`str` name. + :arg domain: The domain on which the :class:`Constant` is defined. + :arg space: The space on which the :class:`Constant` is defined. + :arg space_type: The space type for the :class:`Constant`. `'primal'`, + `'dual'`, `'conjugate'`, or `'conjugate_dual'`. + :arg shape: A :class:`tuple` of :class:`int` objects defining the shape of + the value. + :arg comm: The communicator for the :class:`Constant`. + :arg static: Defines whether the :class:`Constant` is static, meaning that + it is stored by reference in checkpointing/replay, and an associated + tangent-linear variable is zero. + :arg cache: Defines whether results involving the :class:`Constant` may be + cached. Default `static`. + + Remaining arguments are passed to the backend `Constant` constructor. + """ + + def __init__(self, value=None, *args, name=None, domain=None, space=None, + space_type="primal", shape=None, comm=None, static=False, + cache=None, **kwargs): + if space_type not in {"primal", "conjugate", "dual", "conjugate_dual"}: + raise ValueError("Invalid space type") + + if domain is None and space is not None: + domains = space.ufl_domains() + if len(domains) > 0: + domain, = domains + del domains + + # Shape initialization / checking + if space is not None: + if shape is None: + shape = space.ufl_element().value_shape() + elif shape != space.ufl_element().value_shape(): + raise ValueError("Invalid shape") + + value = constant_value(value, shape) + + # Default comm + if comm is None: + if space is None: + comm = DEFAULT_COMM + else: + comm = comm_parent(space_comm(space)) + + if cache is None: + cache = static + + super().__init__( + value, *args, name=name, domain=domain, space=space, + comm=comm, **kwargs) + self._tlm_adjoint__var_interface_attrs.d_setitem("space_type", space_type) # noqa: E501 + self._tlm_adjoint__var_interface_attrs.d_setitem("static", static) + self._tlm_adjoint__var_interface_attrs.d_setitem("cache", cache) + + def __new__(cls, value=None, *args, name=None, domain=None, + space_type="primal", shape=None, static=False, cache=None, + **kwargs): + if issubclass(cls, ufl.classes.Coefficient) or domain is None: + return object().__new__(cls) + else: + # For Firedrake + value = constant_value(value, shape) + if space_type not in {"primal", "conjugate", + "dual", "conjugate_dual"}: + raise ValueError("Invalid space type") + if cache is None: + cache = static + F = super().__new__(cls, value, domain=domain) + F.rename(name=name) + F._tlm_adjoint__var_interface_attrs.d_setitem("space_type", space_type) # noqa: E501 + F._tlm_adjoint__var_interface_attrs.d_setitem("static", static) + F._tlm_adjoint__var_interface_attrs.d_setitem("cache", cache) + return F + + +class Zero: + """Mixin for defining a zero-valued variable. Used for zero-valued + variables for which UFL zero elimination should not be applied. + """ + + def _tlm_adjoint__var_interface_assign(self, y): + raise RuntimeError("Cannot call _assign interface of Zero") + + def _tlm_adjoint__var_interface_axpy(self, alpha, x, /): + raise RuntimeError("Cannot call _axpy interface of Zero") + + def _tlm_adjoint__var_interface_set_values(self, values): + raise RuntimeError("Cannot call _set_values interface of Zero") + + def _tlm_adjoint__var_interface_update_state(self): + raise RuntimeError("Cannot call _update_state interface of Zero") + + +class ZeroConstant(Constant, Zero): + """A :class:`Constant` which is flagged as having a value of zero. + + Arguments are passed to the :class:`Constant` constructor, together with + `static=True` and `cache=True`. + """ + + def __init__(self, *, name=None, domain=None, space=None, + space_type="primal", shape=None, comm=None): + Constant.__init__( + self, name=name, domain=domain, space=space, space_type=space_type, + shape=shape, comm=comm, static=True, cache=True) + if var_linf_norm(self) != 0.0: + raise RuntimeError("ZeroConstant is not zero-valued") + + def __new__(cls, *args, shape=None, **kwargs): + return Constant.__new__( + cls, constant_value(shape=shape), *args, + shape=shape, static=True, cache=True, **kwargs) + + def assign(self, *args, **kwargs): + raise RuntimeError("Cannot call assign method of ZeroConstant") + + +def as_coefficient(x): + if isinstance(x, ufl.classes.Coefficient): + return x + + # For Firedrake + + if not isinstance(x, backend_Constant): + raise TypeError("Unexpected type") + + if not hasattr(x, "_tlm_adjoint__Coefficient"): + if is_var(x): + space = var_space(x) + else: + if len(x.ufl_shape) == 0: + element = ufl.classes.FiniteElement("R", None, 0) + elif len(x.ufl_shape) == 1: + element = ufl.classes.VectorElement("R", None, 0, + dim=x.ufl_shape[0]) + else: + element = ufl.classes.TensorElement("R", None, 0, + shape=x.ufl_shape) + space = ufl.classes.FunctionSpace(None, element) + + x._tlm_adjoint__Coefficient = ufl.classes.Coefficient(space) + + return x._tlm_adjoint__Coefficient + + +def with_coefficient(expr, x): + x_coeff = as_coefficient(x) + if x_coeff is x: + return expr, {}, {} + else: + # For Firedrake + replace_map = {x: x_coeff} + replace_map_inverse = {x_coeff: x} + return ufl.replace(expr, replace_map), replace_map, replace_map_inverse + + +def with_coefficients(expr): + if isinstance(expr, ufl.classes.Form) \ + and "_tlm_adjoint__form_with_coefficients" in expr._cache: + return expr._cache["_tlm_adjoint__form_with_coefficients"] + + if issubclass(backend_Constant, ufl.classes.Coefficient): + replace_map = {} + else: + # For Firedrake + constants = tuple(sorted(ufl.algorithms.extract_type(expr, backend_Constant), # noqa: E501 + key=lambda c: c.count())) + replace_map = dict(zip(constants, map(as_coefficient, constants))) + replace_map_inverse = {c_coeff: c + for c, c_coeff in replace_map.items()} + + expr_with_coeffs = ufl.replace(expr, replace_map) + if isinstance(expr, ufl.classes.Form): + expr._cache["_tlm_adjoint__form_with_coefficients"] = \ + (expr_with_coeffs, replace_map, replace_map_inverse) + return expr_with_coeffs, replace_map, replace_map_inverse + + +def extract_coefficients(expr): + """ + :returns: Variables on which the supplied :class:`ufl.core.expr.Expr` or + :class:`ufl.Form` depends. + """ + + if isinstance(expr, ufl.classes.Form) \ + and "_tlm_adjoint__form_coefficients" in expr._cache: + return expr._cache["_tlm_adjoint__form_coefficients"] + + if issubclass(backend_Constant, ufl.classes.Coefficient): + cls = (ufl.classes.Coefficient,) + else: + # For Firedrake + cls = (ufl.classes.Coefficient, backend_Constant) + deps = [] + for c in cls: + deps.extend(sorted(ufl.algorithms.extract_type(expr, c), + key=lambda dep: dep.count())) + + if isinstance(expr, ufl.classes.Form): + expr._cache["_tlm_adjoint__form_coefficients"] = deps + return deps + + +def derivative(expr, x, argument=None, *, + enable_automatic_argument=True): + expr_arguments = ufl.algorithms.extract_arguments(expr) + arity = len(expr_arguments) + + if argument is None and enable_automatic_argument: + Argument = {0: TestFunction, 1: TrialFunction}[arity] + argument = Argument(var_derivative_space(x)) + + for expr_argument in expr_arguments: + if expr_argument.number() >= arity: + raise ValueError("Unexpected argument") + if isinstance(argument, ufl.classes.Argument) and argument.number() < arity: # noqa: E501 + raise ValueError("Invalid argument") + + if isinstance(expr, ufl.classes.Expr): + expr, replace_map, replace_map_inverse = with_coefficient(expr, x) + else: + expr, replace_map, replace_map_inverse = with_coefficients(expr) + + if argument is not None: + argument, _, argument_replace_map_inverse = with_coefficients(argument) + replace_map_inverse.update(argument_replace_map_inverse) + + dexpr = ufl.derivative(expr, replace_map.get(x, x), argument=argument) + return ufl.replace(dexpr, replace_map_inverse) + + +def eliminate_zeros(expr, *, force_non_empty_form=False): + """Apply zero elimination for + :class:`tlm_adjoint._code_generator.functions.Zero` objects in the supplied + :class:`ufl.core.expr.Expr` or :class:`ufl.Form`. + + :arg expr: A :class:`ufl.core.expr.Expr` or :class:`ufl.Form`. + :arg force_non_empty_form: If `True` and if `expr` is a :class:`ufl.Form`, + then the returned form is guaranteed to be non-empty, and may be + assembled. + :returns: A :class:`ufl.core.expr.Expr` or :class:`ufl.Form` with zero + elimination applied. May return `expr`. + """ + + if isinstance(expr, ufl.classes.Form) \ + and "_tlm_adjoint__simplified_form" in expr._cache: + simplified_expr = expr._cache["_tlm_adjoint__simplified_form"] + else: + replace_map = {} + for c in extract_coefficients(expr): + if isinstance(c, Zero): + replace_map[c] = ufl.classes.Zero(shape=c.ufl_shape) + + if len(replace_map) == 0: + simplified_expr = expr + else: + simplified_expr = ufl.replace(expr, replace_map) + + if isinstance(expr, ufl.classes.Form): + expr._cache["_tlm_adjoint__simplified_form"] = simplified_expr + + if force_non_empty_form \ + and isinstance(simplified_expr, ufl.classes.Form) \ + and simplified_expr.empty(): + if "_tlm_adjoint__simplified_form_non_empty" in expr._cache: + simplified_expr = expr._cache["_tlm_adjoint__simplified_form_non_empty"] # noqa: E501 + else: + # Inefficient, but it is very difficult to generate a non-empty but + # zero valued form + arguments = expr.arguments() + zero = ZeroConstant() + if len(arguments) == 0: + domain, = expr.ufl_domains() + simplified_expr = zero * ufl.ds(domain) + elif len(arguments) == 1: + test, = arguments + simplified_expr = ufl.inner(zero, test[tuple(0 for _ in test.ufl_shape)]) * ufl.ds # noqa: E501 + else: + test, trial = arguments + simplified_expr = zero * ufl.inner(trial[tuple(0 for _ in trial.ufl_shape)], # noqa: E501 + test[tuple(0 for _ in test.ufl_shape)]) * ufl.ds # noqa: E501 + + if isinstance(expr, ufl.classes.Form): + expr._cache["_tlm_adjoint__simplified_form_non_empty"] = simplified_expr # noqa: E501 + + return simplified_expr + + +class DirichletBC(backend_DirichletBC): + """Extends the backend `DirichletBC`. + + :arg static: A flag that indicates that the value for the + :class:`DirichletBC` will not change, and which determines whether + calculations involving this :class:`DirichletBC` can be cached. If + `None` then autodetected from the value. + + Remaining arguments are passed to the backend `DirichletBC` constructor. + """ + + # Based on FEniCS 2019.1.0 DirichletBC API + def __init__(self, V, g, sub_domain, *args, + static=None, _homogeneous=False, **kwargs): + super().__init__(V, g, sub_domain, *args, **kwargs) + + if static is None: + for dep in extract_coefficients( + g if isinstance(g, ufl.classes.Expr) + else Constant(g, static=True)): + if not is_var(dep) or not var_is_static(dep): + static = False + break + else: + static = True + + self._tlm_adjoint__static = static + self._tlm_adjoint__cache = static + self._tlm_adjoint__homogeneous = _homogeneous + + def homogenize(self): + """Homogenize the :class:`DirichletBC`, setting its value to zero. + """ + + if self._tlm_adjoint__static: + raise RuntimeError("Cannot call homogenize method for static " + "DirichletBC") + if not self._tlm_adjoint__homogeneous: + super().homogenize() + self._tlm_adjoint__homogeneous = True + + def set_value(self, *args, **kwargs): + """Set the :class:`DirichletBC` value. + + Arguments are passed to the base class `set_value` method. + """ + + if self._tlm_adjoint__static: + raise RuntimeError("Cannot call set_value method for static " + "DirichletBC") + super().set_value(*args, **kwargs) + + +class HomogeneousDirichletBC(DirichletBC): + """A :class:`DirichletBC` whose value is zero. + + Arguments are passed to the :class:`DirichletBC` constructor, together with + `static=True`. + """ + + # Based on FEniCS 2019.1.0 DirichletBC API + def __init__(self, V, sub_domain, *args, **kwargs): + shape = V.ufl_element().value_shape() + if len(shape) == 0: + g = 0.0 + else: + g = np.zeros(shape, dtype=backend_ScalarType) + super().__init__(V, g, sub_domain, *args, static=True, + _homogeneous=True, **kwargs) + + +def bcs_is_static(bcs): + if isinstance(bcs, backend_DirichletBC): + bcs = (bcs,) + for bc in bcs: + if not getattr(bc, "_tlm_adjoint__static", False): + return False + return True + + +def bcs_is_cached(bcs): + if isinstance(bcs, backend_DirichletBC): + bcs = (bcs,) + for bc in bcs: + if not getattr(bc, "_tlm_adjoint__cache", False): + return False + return True + + +def bcs_is_homogeneous(bcs): + if isinstance(bcs, backend_DirichletBC): + bcs = (bcs,) + for bc in bcs: + if not getattr(bc, "_tlm_adjoint__homogeneous", False): + return False + return True + + +class ReplacementInterface(_VariableInterface): + def _space(self): + return self.ufl_function_space() + + def _derivative_space(self): + return self._tlm_adjoint__var_interface_attrs.get( + "derivative_space", lambda x: var_space(x))(self) + + def _space_type(self): + return self._tlm_adjoint__var_interface_attrs["space_type"] + + def _id(self): + return self._tlm_adjoint__var_interface_attrs["id"] + + def _name(self): + return self._tlm_adjoint__var_interface_attrs["name"] + + def _state(self): + return -1 + + def _is_static(self): + return self._tlm_adjoint__var_interface_attrs["static"] + + def _is_cached(self): + return self._tlm_adjoint__var_interface_attrs["cache"] + + def _caches(self): + return self._tlm_adjoint__var_interface_attrs["caches"] + + def _replacement(self): + return self + + def _is_replacement(self): + return True + + +class Replacement(ufl.classes.Coefficient): + """A :class:`ufl.Coefficient` representing a symbolic variable but with no + value. + """ + + def __init__(self, x): + space = var_space(x) + + x_domains = x.ufl_domains() + if len(x_domains) == 0: + domain = None + else: + domain, = x_domains + + super().__init__(space, count=x.count()) + self._tlm_adjoint__domain = domain + add_interface(self, ReplacementInterface, + {"id": var_id(x), "name": var_name(x), + "space": space, + "space_type": var_space_type(x), + "static": var_is_static(x), + "cache": var_is_cached(x), + "caches": var_caches(x)}) + + def ufl_domain(self): + return self._tlm_adjoint__domain + + def ufl_domains(self): + if self._tlm_adjoint__domain is None: + return () + else: + return (self._tlm_adjoint__domain,) + + +class ReplacementConstant(Replacement): + """Represents a symbolic constant, but has no value. + """ + + def __init__(self, x): + super().__init__(x) + self._tlm_adjoint__var_interface_attrs["derivative_space"] \ + = x._tlm_adjoint__var_interface_attrs["derivative_space"] + + +class ReplacementFunction(Replacement): + """Represents a symbolic backend `Function`, but has no value. + """ + + def function_space(self): + return var_space(self) + + +def replaced_form(form): + replace_map = {} + for c in extract_coefficients(form): + if is_var(c) and not var_is_replacement(c): + c_rep = var_replacement(c) + if c_rep is not c: + replace_map[c] = c_rep + return ufl.replace(form, replace_map) + + +def define_var_alias(x, parent, *, key): + if x is not parent: + if "alias" in x._tlm_adjoint__var_interface_attrs: + alias_parent, alias_key = x._tlm_adjoint__var_interface_attrs["alias"] # noqa: E501 + alias_parent = alias_parent() + if alias_parent is None or alias_parent is not parent \ + or alias_key != key: + raise ValueError("Invalid alias data") + else: + x._tlm_adjoint__var_interface_attrs["alias"] \ + = (weakref.ref(parent), key) + x._tlm_adjoint__var_interface_attrs.d_setitem( + "space_type", var_space_type(parent)) + x._tlm_adjoint__var_interface_attrs.d_setitem( + "static", var_is_static(parent)) + x._tlm_adjoint__var_interface_attrs.d_setitem( + "cache", var_is_cached(parent)) + x._tlm_adjoint__var_interface_attrs.d_setitem( + "state", parent._tlm_adjoint__var_interface_attrs["state"]) diff --git a/tlm_adjoint/firedrake/hessian_system.py b/tlm_adjoint/firedrake/hessian_system.py new file mode 100644 index 00000000..95f47e51 --- /dev/null +++ b/tlm_adjoint/firedrake/hessian_system.py @@ -0,0 +1,433 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from ..interface import ( + check_space_types, comm_dup_cached, is_var, space_comm, space_dtype, + space_new, var_assign, var_axpy, var_axpy_conjugate, var_copy, + var_copy_conjugate, var_dtype, var_inner, var_space, var_space_type) + +from ..eigendecomposition import eigendecompose +from ..manager import manager_disabled + +from .block_system import ( + BackendMixedSpace, BlockNullspace, Matrix, NoneNullspace, Preconditioner, + System, iter_sub, tuple_sub) + +from collections.abc import Sequence +import numpy as np +import petsc4py.PETSc as PETSc +import warnings + +__all__ = \ + [ + "HessianSystem", + "hessian_eigendecompose", + "B_inv_orthonormality_test", + "hessian_eigendecomposition_pc", + ] + + +class MixedSpace(BackendMixedSpace): + def __init__(self, spaces, space_types=None): + if isinstance(spaces, Sequence): + spaces = tuple(spaces) + else: + spaces = (spaces,) + spaces = tuple_sub(spaces, spaces) + + if space_types is None: + space_types = "primal" + if space_types in {"primal", "conjugate", "dual", "conjugate_dual"}: + space_types = tuple(space_types for _ in iter_sub(spaces)) + else: + space_types = tuple(iter_sub(space_types)) + + super().__init__(spaces) + if len(space_types) != len(self.flattened_space()): + raise ValueError("Invalid space types") + self._space_types = space_types + + def new_split(self): + flattened_space = self.flattened_space() + assert len(flattened_space) == len(self._space_types) + return tuple_sub((space_new(space, space_type=space_type) + for space, space_type in zip(flattened_space, self._space_types)), # noqa: E501 + self.split_space()) + + def new_mixed(self): + space_type = self._space_types[0] + return space_new(self.mixed_space(), space_type=space_type) + + +# Complex note: It is convenient to define a Hessian matrix action in terms of +# the *conjugate* of the action, i.e. (H \zeta)^{*,T}, e.g. this is the form +# returned by reverse-over-forward AD. However complex conjugation is then +# needed in a number of places (e.g. one cannot define an eigenproblem directly +# in terms of the conjugate of an action, as this is antilinear, rather than +# linear). + + +class HessianMatrix(Matrix): + def __init__(self, H, M): + if is_var(M): + M = (M,) + else: + M = tuple(M) + space = tuple(map(var_space, M)) + + super().__init__(space, space) + self._H = H + self._M = M + + def mult_add(self, x, y): + if is_var(x): + x = (x,) + if is_var(y): + y = (y,) + + if len(x) != len(self._M): + raise ValueError("Invalid Hessian argument") + for x_i, m in zip(x, self._M): + check_space_types(x_i, m) + + _, _, ddJ = self._H.action(self._M, x) + + if len(y) != len(ddJ): + raise ValueError("Invalid Hessian action") + for y_i, ddJ_i in zip(y, ddJ): + var_axpy_conjugate(y_i, 1.0, ddJ_i) + + +class HessianSystem(System): + """Defines a linear system involving a Hessian matrix, + + .. math:: + + H u = b. + + :arg H: A :class:`tlm_adjoint.hessian.Hessian` defining :math:`H`. + :arg M: A backend `Function`, or a :class:`Sequence` of backend `Function` + objects, defining the control. + :arg nullspace: A + :class:`tlm_adjoint._code_generator.block_system.Nullspace` or a + :class:`Sequence` of + :class:`tlm_adjoint._code_generator.block_system.Nullspace` objects + defining the nullspace and left nullspace of the Hessian matrix. `None` + indicates a + :class:`tlm_adjoint._code_generator.block_system.NoneNullspace`. + :arg comm: MPI communicator. + """ + + def __init__(self, H, M, *, + nullspace=None, comm=None): + if is_var(M): + M = (M,) + + arg_spaces = MixedSpace( + (tuple(map(var_space, M)),), + space_types=tuple(map(var_space_type, M))) + action_spaces = MixedSpace( + (tuple(map(var_space, M)),), + space_types=tuple(var_space_type(m, rel_space_type="dual") + for m in M)) + + matrix = HessianMatrix(H, M) + + if comm is None: + comm = space_comm(arg_spaces.mixed_space()) + comm = comm_dup_cached(comm, key="HessianSystem") + + super().__init__( + arg_spaces, action_spaces, matrix, + nullspaces=BlockNullspace(nullspace), comm=comm) + + @manager_disabled() + def solve(self, u, b, **kwargs): + """Solve a linear system involving a Hessian matrix, + + .. math:: + + H u = b. + + :arg u: A backend `Function`, or a :class:`Sequence` of backend + `Function` objects, defining the solution :math:`u`. + :arg b: A backend `Function`, or a :class:`Sequence` of backend + `Function` objects, defining the conjugate of the right-hand-side + :math:`b`. + + Remaining arguments are handed to the base class :meth:`solve` method. + """ + + if is_var(b): + b = var_copy_conjugate(b) + else: + b = tuple_sub(map(var_copy_conjugate, iter_sub(b)), b) + return super().solve(u, b, **kwargs) + + +def hessian_eigendecompose( + H, m, B_inv_action, B_action, *, + nullspace=None, problem_type=None, pre_callback=None, + correct_eigenvectors=True, **kwargs): + r"""Interface with SLEPc via slepc4py, for the matrix free solution of + generalized eigenproblems + + .. math:: + + H v = \lambda B^{-1} v, + + where :math:`H` is a Hessian matrix. + + Despite the notation :math:`B^{-1}` may be singular, defining an inverse + operator only on an appropriate subspace. + + :arg H: A :class:`tlm_adjoint.hessian.Hessian`. + :arg m: A backend `Function` defining the control. + :arg B_inv_action: A callable accepting a backend `Function` defining `v` + and computing the conjugate of the action of :math:`B^{-1}` on + :math:`v`, returning the result as a backend `Function`. + :arg B_action: A callable accepting a backend `Function` defining :math:`v` + and computing the action of :math:`B` on the conjugate of :math:`v`, + returning the result as a backend `Function`. + :arg nullspace: A + :class:`tlm_adjoint._code_generator.block_system.Nullspace` defining + the nullspace and left nullspace of :math:`H` and :math:`B^{-1}`. + :arg problem_type: The eigenproblem type -- see + :class:`slepc4py.SLEPc.EPS.ProblemType`. Defaults to + `slepc4py.SLEPc.EPS.ProblemType.GHEP` in the real case and + `slepc4py.SLEPc.EPS.ProblemType.GNHEP` in the complex case. + :arg pre_callback: A callable accepting a single + :class:`slepc4py.SLEPc.EPS` argument. Used for detailed manual + configuration. Called after all other configuration options are set, + but before the :meth:`EPS.setUp` method is called. + :arg correct_eigenvectors: Whether to apply a nullspace correction to the + eigenvectors. + + Remaining keyword arguments are passed to + :func:`tlm_adjoint.eigendecomposition.eigendecompose`. + """ + + space = var_space(m) + + arg_space_type = var_space_type(m) + arg_space = MixedSpace(space, space_types=arg_space_type) + assert arg_space.split_space() == (space,) + assert arg_space.flattened_space() == (space,) + assert arg_space.mixed_space() == space + + action_space_type = var_space_type(m, rel_space_type="dual") + action_space = MixedSpace(space, space_types=action_space_type) + assert action_space.split_space() == (space,) + assert action_space.flattened_space() == (space,) + assert action_space.mixed_space() == space + + if nullspace is None: + nullspace = NoneNullspace() + + def H_action(x): + x = var_copy(x) + nullspace.pre_mult_correct_lhs(x) + _, _, y = H.action(m, x) + y = var_copy_conjugate(y) + nullspace.post_mult_correct_lhs(None, y) + return y + + B_inv_action_arg = B_inv_action + + def B_inv_action(x): + x = var_copy(x) + nullspace.pre_mult_correct_lhs(x) + y = B_inv_action_arg(var_copy(x)) + y = var_copy_conjugate(y) + nullspace.post_mult_correct_lhs(x, y) + return y + + B_action_arg = B_action + + def B_action(x, y): + x, = x + y, = tuple(map(var_copy_conjugate, y)) + # Nullspace corrections applied by the Preconditioner class + var_assign(x, B_action_arg(y)) + + pre_callback_arg = pre_callback + + def pre_callback(eps): + _, B_inv = eps.getOperators() + ksp_solver = eps.getST().getKSP() + + B_pc = Preconditioner( + action_space, arg_space, + B_action, BlockNullspace(nullspace)) + pc = PETSc.PC().createPython( + B_pc, comm=ksp_solver.comm) + pc.setOperators(B_inv) + pc.setUp() + + ksp_solver.setType(PETSc.KSP.Type.PREONLY) + ksp_solver.setTolerances(rtol=0.0, atol=0.0, divtol=None, max_it=1) + ksp_solver.setPC(pc) + ksp_solver.setUp() + + if hasattr(eps, "setPurify"): + eps.setPurify(False) + else: + warnings.warn("slepc4py.SLEPc.EPS.setPurify not available", + RuntimeWarning) + + if pre_callback_arg is not None: + pre_callback_arg(eps) + + if problem_type is None: + import slepc4py.SLEPc as SLEPc + if issubclass(space_dtype(space), (float, np.floating)): + problem_type = SLEPc.EPS.ProblemType.GHEP + else: + problem_type = SLEPc.EPS.ProblemType.GNHEP + + Lam, V = eigendecompose( + space, H_action, B_action=B_inv_action, arg_space_type=arg_space_type, + action_space_type=action_space_type, problem_type=problem_type, + pre_callback=pre_callback, **kwargs) + + if correct_eigenvectors: + if len(V) == 2 \ + and isinstance(V[0], Sequence) \ + and isinstance(V[1], Sequence): + assert len(V[0]) == len(V[1]) + assert len(V[0]) == len(Lam) + for V_r, V_i in zip(*V): + nullspace.correct_soln(V_r) + nullspace.correct_soln(V_i) + else: + assert len(V) == len(Lam) + for V_r in V: + nullspace.correct_soln(V_r) + + return Lam, V + + +def B_inv_orthonormality_test(V, B_inv_action): + """Check for :math:`B^{-1}`-orthonormality. + + Requires real spaces. + + :arg B_inv_action: A callable accepting a backend `Function` defining `v` + and computing the action of :math:`B^{-1}` on :math:`v`, returning the + result as a backend `Function`. + :arg V: A :class:`Sequence` of backend `Function` objects to test for + :math:`B^{-1}`-orthonormality. + :returns: A :class:`tuple` `(max_diagonal_error_norm, + max_off_diagonal_error_norm)` with + + - `max_diagonal_error_norm`: The maximum :math:`B^{-1}` + normalization error magnitude. + - `max_diagonal_error_norm`: The maximum :math:`B^{-1}` + orthogonality error magnitude. + """ + + if len(V) == 2 \ + and isinstance(V[0], Sequence) \ + and isinstance(V[1], Sequence): + raise ValueError("Cannot supply separate real/complex eigenvector " + "components") + + B_inv_V = [] + for v in V: + if not issubclass(var_dtype(v), (float, np.floating)): + raise ValueError("Real dtype required") + B_inv_V.append(B_inv_action(var_copy(v))) + if not issubclass(var_dtype(B_inv_V[-1]), (float, np.floating)): + raise ValueError("Real dtype required") + + max_diagonal_error_norm = 0.0 + max_off_diagonal_error_norm = 0.0 + assert len(V) == len(B_inv_V) + for i, v in enumerate(V): + for j, B_inv_v in enumerate(B_inv_V): + if i == j: + max_diagonal_error_norm = max( + max_diagonal_error_norm, + abs(var_inner(v, B_inv_v) - 1.0)) + else: + max_off_diagonal_error_norm = max( + max_off_diagonal_error_norm, + abs(var_inner(v, B_inv_v))) + + return max_diagonal_error_norm, max_off_diagonal_error_norm + + +def hessian_eigendecomposition_pc(B_action, Lam, V): + r"""Construct a Hessian matrix preconditioner using a partial spectrum + generalized eigendecomposition. Assumes that the Hessian matrix consists of + two terms + + .. math:: + + H = R^{-1} + B^{-1}, + + where :math:`R` and :math:`B` are symmetric. + + Assumes real spaces. Despite the notation :math:`R^{-1}` and :math:`B^{-1}` + (and later :math:`H^{-1}`) may be singular, defining inverse operators only + on an appropriate subspace. :math:`B` is assumed to define a symmetric + positive definite operator on that subspace. + + The approximation is defined via + + .. math:: + + H^{-1} \approx B + V \Lambda \left( I + \Lambda \right)^{-1} V^T + + where + + .. math:: + + R^{-1} V = B^{-1} V \Lambda, + + and where :math:`\Lambda` is a diagonal matrix and :math:`V` has + :math:`B^{-1}`-orthonormal columns, :math:`V^T B^{-1} V = I`. + + This low rank update approximation for the Hessian matrix inverse is + described in + + - Tobin Isaac, Noemi Petra, Georg Stadler, and Omar Ghattas, 'Scalable + and efficient algorithms for the propagation of uncertainty from data + through inference to prediction for large-scale problems, with + application to flow of the Antarctic ice sheet', Journal of + Computational Physics, 296, pp. 348--368, 2015, doi: + 10.1016/j.jcp.2015.04.047 + + See in particular their equation (20). + + :arg B_action: A callable accepting a backend `Function` defining :math:`v` + and computing the action of :math:`B` on :math:`v`, returning the + result as a backend `Function`. + :arg Lam: A :class:`Sequence` defining the diagonal of :math:`\Lambda`. + :arg V: A :class:`Sequence` of backend `Function` objects defining the + columns of :math:`V`. + :returns: A callable suitable for use as the `pc_fn` argument to + :meth:`HessianSystem.solve`. + """ + + if len(V) == 2 \ + and isinstance(V[0], Sequence) \ + and isinstance(V[1], Sequence): + raise ValueError("Cannot supply separate real/complex eigenvector " + "components") + + Lam = tuple(Lam) + V = tuple(V) + if len(Lam) != len(V): + raise ValueError("Invalid eigenpairs") + + def pc_fn(u, b): + b = var_copy_conjugate(b) + var_assign(u, B_action(var_copy(b))) + + assert len(Lam) == len(V) + for lam, v in zip(Lam, V): + alpha = -(lam / (1.0 + lam)) * var_inner(b, v) + var_axpy(u, alpha, v) + + return pc_fn From 1061c50d097326ce80415837d96a4e4c094a79c8 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Mon, 16 Oct 2023 09:58:19 +0100 Subject: [PATCH 2/3] Remove backend dependent code --- tlm_adjoint/fenics/block_system.py | 250 ++++++++------------------ tlm_adjoint/fenics/caches.py | 7 +- tlm_adjoint/fenics/equations.py | 11 +- tlm_adjoint/fenics/functions.py | 121 +------------ tlm_adjoint/firedrake/block_system.py | 243 +++++++++---------------- tlm_adjoint/firedrake/caches.py | 7 +- tlm_adjoint/firedrake/equations.py | 6 +- tlm_adjoint/firedrake/functions.py | 30 ++-- 8 files changed, 188 insertions(+), 487 deletions(-) diff --git a/tlm_adjoint/fenics/block_system.py b/tlm_adjoint/fenics/block_system.py index 460a6958..39438610 100644 --- a/tlm_adjoint/fenics/block_system.py +++ b/tlm_adjoint/fenics/block_system.py @@ -1,8 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -r"""This module is used by both the FEniCS and Firedrake backends, and -implements solvers for linear systems defined in mixed spaces. +r"""This module implements solvers for linear systems defined in mixed spaces. The :class:`System` class defines the block structure of the linear system, and solves the system using an outer Krylov solver. A custom preconditioner can be @@ -82,22 +81,11 @@ with other tlm_adjoint code, space type warnings may be encountered. """ -# This is the only import from other tlm_adjoint modules that is permitted in -# this module -from .backend import backend - -if backend == "Firedrake": - from firedrake import ( - Cofunction, Constant, DirichletBC, Function, FunctionSpace, - TestFunction, assemble) - from firedrake.functionspaceimpl import WithGeometry as FunctionSpaceBase -elif backend == "FEniCS": - from fenics import ( - Constant, DirichletBC, Function, FunctionSpace, TestFunction, assemble) - from fenics import FunctionAssigner, as_backend_type - from dolfin.cpp.function import Constant as cpp_Constant -else: - raise ImportError(f"Unexpected backend: {backend}") +from fenics import ( + Constant, DirichletBC, Function, FunctionSpace, TestFunction, assemble) +from fenics import FunctionAssigner, as_backend_type +from dolfin.cpp.function import Constant as cpp_Constant + import petsc4py.PETSc as PETSc import ufl @@ -327,162 +315,78 @@ def split_to_mixed(self, u_fn, u): raise NotImplementedError -if backend == "Firedrake": - def mesh_comm(mesh): - return mesh.comm - - class BackendMixedSpace(MixedSpace): - def __init__(self, spaces): - if isinstance(spaces, Sequence): - spaces = tuple(spaces) - else: - spaces = (spaces,) - spaces = tuple_sub(spaces, spaces) - super().__init__( - tuple_sub((space if isinstance(space, FunctionSpaceBase) - else space.dual() for space in iter_sub(spaces)), - spaces)) - self._primal_dual_spaces = tuple(iter_sub(spaces)) - assert len(self._primal_dual_spaces) == len(self._flattened_spaces) - - def new_split(self): - u = [] - for space in self._primal_dual_spaces: - if isinstance(space, FunctionSpaceBase): - u.append(Function(space)) - else: - u.append(Cofunction(space)) - return tuple_sub(u, self._spaces) - - @staticmethod - def _iter_sub_fn(iterable): - def expand(e): - if isinstance(e, (Cofunction, Function)): - space = e.function_space() - if hasattr(space, "num_sub_spaces"): - return tuple(e.sub(i) - for i in range(space.num_sub_spaces())) - return e - - return iter_sub(iterable, expand=expand) - - def mixed_to_split(self, u, u_fn): - if len(self._flattened_spaces) == 1: - u, = tuple(iter_sub(u)) - with u_fn.dat.vec_ro as u_fn_v, u.dat.vec_wo as u_v: - u_fn_v.copy(result=u_v) - else: - for i, u_i in enumerate(self._iter_sub_fn(u)): - with u_fn.sub(i).dat.vec_ro as u_fn_i_v, u_i.dat.vec_wo as u_i_v: # noqa: E501 - u_fn_i_v.copy(result=u_i_v) - - def split_to_mixed(self, u_fn, u): - if len(self._flattened_spaces) == 1: - u, = tuple(iter_sub(u)) - with u_fn.dat.vec_wo as u_fn_v, u.dat.vec_ro as u_v: - u_v.copy(result=u_fn_v) - else: - for i, u_i in enumerate(self._iter_sub_fn(u)): - with u_fn.sub(i).dat.vec_wo as u_fn_i_v, u_i.dat.vec_ro as u_i_v: # noqa: E501 - u_i_v.copy(result=u_fn_i_v) - - def mat(a): - return a.petscmat - - @contextmanager - def vec(u, mode=Access.RW): - attribute_name = {Access.RW: "vec", - Access.READ: "vec_ro", - Access.WRITE: "vec_wo"}[mode] - with getattr(u.dat, attribute_name) as u_v: - yield u_v - - def bc_space(bc): - return bc.function_space() - - def bc_is_homogeneous(bc): - return isinstance(bc.function_arg, ufl.classes.Zero) - - def bc_domain_args(bc): - return (bc.sub_domain,) - - def apply_bcs(u, bcs): - if not isinstance(bcs, Sequence): - bcs = (bcs,) - if len(bcs) > 0 and not isinstance(u.function_space(), type(bcs[0].function_space())): # noqa: E501 - u_bc = u.riesz_representation("l2") +def mesh_comm(mesh): + return mesh.mpi_comm() + + +class BackendMixedSpace(MixedSpace): + @cached_property + def _mixed_to_split_assigner(self): + return FunctionAssigner(list(self._flattened_spaces), + self._mixed_space) + + @cached_property + def _split_to_mixed_assigner(self): + return FunctionAssigner(self._mixed_space, + list(self._flattened_spaces)) + + def mixed_to_split(self, u, u_fn): + if len(self._flattened_spaces) == 1: + u, = tuple(iter_sub(u)) + u.assign(u_fn) else: - u_bc = u - for bc in bcs: - bc.apply(u_bc) -elif backend == "FEniCS": - def mesh_comm(mesh): - return mesh.mpi_comm() - - class BackendMixedSpace(MixedSpace): - @cached_property - def _mixed_to_split_assigner(self): - return FunctionAssigner(list(self._flattened_spaces), - self._mixed_space) - - @cached_property - def _split_to_mixed_assigner(self): - return FunctionAssigner(self._mixed_space, - list(self._flattened_spaces)) - - def mixed_to_split(self, u, u_fn): - if len(self._flattened_spaces) == 1: - u, = tuple(iter_sub(u)) - u.assign(u_fn) - else: - self._mixed_to_split_assigner.assign(list(iter_sub(u)), - u_fn) - - def split_to_mixed(self, u_fn, u): - if len(self._flattened_spaces) == 1: - u, = tuple(iter_sub(u)) - u_fn.assign(u) - else: - self._split_to_mixed_assigner.assign(u_fn, - list(iter_sub(u))) - - def mat(a): - matrix = as_backend_type(a).mat() - if not isinstance(matrix, PETSc.Mat): - raise RuntimeError("PETSc backend required") - return matrix - - @contextmanager - def vec(u, mode=Access.RW): - if isinstance(u, Function): - u = u.vector() - u_v = as_backend_type(u).vec() - if not isinstance(u_v, PETSc.Vec): - raise RuntimeError("PETSc backend required") - - yield u_v - - if mode in {Access.RW, Access.WRITE}: - u.update_ghost_values() - - def bc_space(bc): - return FunctionSpace(bc.function_space()) - - def bc_is_homogeneous(bc): - # A weaker check with FEniCS, as the constant might be modified - return (isinstance(bc.value(), cpp_Constant) - and (bc.value().values() == 0.0).all()) - - def bc_domain_args(bc): - return (bc.sub_domain, bc.method()) - - def apply_bcs(u, bcs): - if not isinstance(bcs, Sequence): - bcs = (bcs,) - for bc in bcs: - bc.apply(u.vector()) -else: - raise ImportError(f"Unexpected backend: {backend}") + self._mixed_to_split_assigner.assign(list(iter_sub(u)), + u_fn) + + def split_to_mixed(self, u_fn, u): + if len(self._flattened_spaces) == 1: + u, = tuple(iter_sub(u)) + u_fn.assign(u) + else: + self._split_to_mixed_assigner.assign(u_fn, + list(iter_sub(u))) + + +def mat(a): + matrix = as_backend_type(a).mat() + if not isinstance(matrix, PETSc.Mat): + raise RuntimeError("PETSc backend required") + return matrix + + +@contextmanager +def vec(u, mode=Access.RW): + if isinstance(u, Function): + u = u.vector() + u_v = as_backend_type(u).vec() + if not isinstance(u_v, PETSc.Vec): + raise RuntimeError("PETSc backend required") + + yield u_v + + if mode in {Access.RW, Access.WRITE}: + u.update_ghost_values() + + +def bc_space(bc): + return FunctionSpace(bc.function_space()) + + +def bc_is_homogeneous(bc): + # Note that the constant might be modified + return (isinstance(bc.value(), cpp_Constant) + and (bc.value().values() == 0.0).all()) + + +def bc_domain_args(bc): + return (bc.sub_domain, bc.method()) + + +def apply_bcs(u, bcs): + if not isinstance(bcs, Sequence): + bcs = (bcs,) + for bc in bcs: + bc.apply(u.vector()) class Nullspace(ABC): diff --git a/tlm_adjoint/fenics/caches.py b/tlm_adjoint/fenics/caches.py index 30339d61..3cfcb598 100644 --- a/tlm_adjoint/fenics/caches.py +++ b/tlm_adjoint/fenics/caches.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -"""This module is used by both the FEniCS and Firedrake backends, and -implements finite element assembly and linear solver data caching. +"""This module implements finite element assembly and linear solver data +caching. """ from .backend import TrialFunction, backend_DirichletBC, backend_Function @@ -311,8 +311,7 @@ def assemble(self, form, *, :arg bcs: Dirichlet boundary conditions. :arg form_compiler_parameters: Form compiler parameters. :arg linear_solver_parameters: Linear solver parameters. Required for - assembly parameters which appear in the linear solver parameters - -- in particular the Firedrake `'mat_type'` parameter. + assembly parameters which appear in the linear solver parameters. :arg replace_map: A :class:`Mapping` defining a map from symbolic variables to values. :returns: A :class:`tuple` `(value_ref, value)`, where `value` is the diff --git a/tlm_adjoint/fenics/equations.py b/tlm_adjoint/fenics/equations.py index 58e039e5..9ae01457 100644 --- a/tlm_adjoint/fenics/equations.py +++ b/tlm_adjoint/fenics/equations.py @@ -1,8 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -"""This module is used by both the FEniCS and Firedrake backends, and -implements finite element calculations. In particular the +"""This module implements finite element calculations. In particular the :class:`EquationSolver` class implements the solution of finite element variational problems. """ @@ -812,13 +811,13 @@ def __init__(self, x, rhs, *args, **kwargs): class DirichletBCApplication(Equation): r"""Represents the application of a Dirichlet boundary condition to a zero - valued backend `Function`. Specifically, with the Firedrake backend this - represents: + valued backend `Function`. Specifically this represents: .. code-block:: python - x.zero() - DirichletBC(x.function_space(), y, *args, **kwargs).apply(x) + x.vector().zero() + DirichletBC(x.function_space(), y, + *args, **kwargs).apply(x.vector()) The forward residual :math:`\mathcal{F}` is defined so that :math:`\partial \mathcal{F} / \partial x` is the identity. diff --git a/tlm_adjoint/fenics/functions.py b/tlm_adjoint/fenics/functions.py index 43cbb66e..cd65c86a 100644 --- a/tlm_adjoint/fenics/functions.py +++ b/tlm_adjoint/fenics/functions.py @@ -1,9 +1,8 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -"""This module is used by both the FEniCS and Firedrake backends, and includes -functionality for handling :class:`ufl.Coefficient` objects and boundary -conditions. +"""This module includes functionality for handling :class:`ufl.Coefficient` +objects and boundary conditions. """ from .backend import ( @@ -310,26 +309,6 @@ def __init__(self, value=None, *args, name=None, domain=None, space=None, self._tlm_adjoint__var_interface_attrs.d_setitem("static", static) self._tlm_adjoint__var_interface_attrs.d_setitem("cache", cache) - def __new__(cls, value=None, *args, name=None, domain=None, - space_type="primal", shape=None, static=False, cache=None, - **kwargs): - if issubclass(cls, ufl.classes.Coefficient) or domain is None: - return object().__new__(cls) - else: - # For Firedrake - value = constant_value(value, shape) - if space_type not in {"primal", "conjugate", - "dual", "conjugate_dual"}: - raise ValueError("Invalid space type") - if cache is None: - cache = static - F = super().__new__(cls, value, domain=domain) - F.rename(name=name) - F._tlm_adjoint__var_interface_attrs.d_setitem("space_type", space_type) # noqa: E501 - F._tlm_adjoint__var_interface_attrs.d_setitem("static", static) - F._tlm_adjoint__var_interface_attrs.d_setitem("cache", cache) - return F - class Zero: """Mixin for defining a zero-valued variable. Used for zero-valued @@ -364,99 +343,17 @@ def __init__(self, *, name=None, domain=None, space=None, if var_linf_norm(self) != 0.0: raise RuntimeError("ZeroConstant is not zero-valued") - def __new__(cls, *args, shape=None, **kwargs): - return Constant.__new__( - cls, constant_value(shape=shape), *args, - shape=shape, static=True, cache=True, **kwargs) - def assign(self, *args, **kwargs): raise RuntimeError("Cannot call assign method of ZeroConstant") -def as_coefficient(x): - if isinstance(x, ufl.classes.Coefficient): - return x - - # For Firedrake - - if not isinstance(x, backend_Constant): - raise TypeError("Unexpected type") - - if not hasattr(x, "_tlm_adjoint__Coefficient"): - if is_var(x): - space = var_space(x) - else: - if len(x.ufl_shape) == 0: - element = ufl.classes.FiniteElement("R", None, 0) - elif len(x.ufl_shape) == 1: - element = ufl.classes.VectorElement("R", None, 0, - dim=x.ufl_shape[0]) - else: - element = ufl.classes.TensorElement("R", None, 0, - shape=x.ufl_shape) - space = ufl.classes.FunctionSpace(None, element) - - x._tlm_adjoint__Coefficient = ufl.classes.Coefficient(space) - - return x._tlm_adjoint__Coefficient - - -def with_coefficient(expr, x): - x_coeff = as_coefficient(x) - if x_coeff is x: - return expr, {}, {} - else: - # For Firedrake - replace_map = {x: x_coeff} - replace_map_inverse = {x_coeff: x} - return ufl.replace(expr, replace_map), replace_map, replace_map_inverse - - -def with_coefficients(expr): - if isinstance(expr, ufl.classes.Form) \ - and "_tlm_adjoint__form_with_coefficients" in expr._cache: - return expr._cache["_tlm_adjoint__form_with_coefficients"] - - if issubclass(backend_Constant, ufl.classes.Coefficient): - replace_map = {} - else: - # For Firedrake - constants = tuple(sorted(ufl.algorithms.extract_type(expr, backend_Constant), # noqa: E501 - key=lambda c: c.count())) - replace_map = dict(zip(constants, map(as_coefficient, constants))) - replace_map_inverse = {c_coeff: c - for c, c_coeff in replace_map.items()} - - expr_with_coeffs = ufl.replace(expr, replace_map) - if isinstance(expr, ufl.classes.Form): - expr._cache["_tlm_adjoint__form_with_coefficients"] = \ - (expr_with_coeffs, replace_map, replace_map_inverse) - return expr_with_coeffs, replace_map, replace_map_inverse - - def extract_coefficients(expr): """ :returns: Variables on which the supplied :class:`ufl.core.expr.Expr` or :class:`ufl.Form` depends. """ - if isinstance(expr, ufl.classes.Form) \ - and "_tlm_adjoint__form_coefficients" in expr._cache: - return expr._cache["_tlm_adjoint__form_coefficients"] - - if issubclass(backend_Constant, ufl.classes.Coefficient): - cls = (ufl.classes.Coefficient,) - else: - # For Firedrake - cls = (ufl.classes.Coefficient, backend_Constant) - deps = [] - for c in cls: - deps.extend(sorted(ufl.algorithms.extract_type(expr, c), - key=lambda dep: dep.count())) - - if isinstance(expr, ufl.classes.Form): - expr._cache["_tlm_adjoint__form_coefficients"] = deps - return deps + return ufl.algorithms.extract_coefficients(expr) def derivative(expr, x, argument=None, *, @@ -474,17 +371,7 @@ def derivative(expr, x, argument=None, *, if isinstance(argument, ufl.classes.Argument) and argument.number() < arity: # noqa: E501 raise ValueError("Invalid argument") - if isinstance(expr, ufl.classes.Expr): - expr, replace_map, replace_map_inverse = with_coefficient(expr, x) - else: - expr, replace_map, replace_map_inverse = with_coefficients(expr) - - if argument is not None: - argument, _, argument_replace_map_inverse = with_coefficients(argument) - replace_map_inverse.update(argument_replace_map_inverse) - - dexpr = ufl.derivative(expr, replace_map.get(x, x), argument=argument) - return ufl.replace(dexpr, replace_map_inverse) + return ufl.derivative(expr, x, argument=argument) def eliminate_zeros(expr, *, force_non_empty_form=False): diff --git a/tlm_adjoint/firedrake/block_system.py b/tlm_adjoint/firedrake/block_system.py index 460a6958..40cb6181 100644 --- a/tlm_adjoint/firedrake/block_system.py +++ b/tlm_adjoint/firedrake/block_system.py @@ -1,8 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -r"""This module is used by both the FEniCS and Firedrake backends, and -implements solvers for linear systems defined in mixed spaces. +r"""This module implements solvers for linear systems defined in mixed spaces. The :class:`System` class defines the block structure of the linear system, and solves the system using an outer Krylov solver. A custom preconditioner can be @@ -82,22 +81,10 @@ with other tlm_adjoint code, space type warnings may be encountered. """ -# This is the only import from other tlm_adjoint modules that is permitted in -# this module -from .backend import backend - -if backend == "Firedrake": - from firedrake import ( - Cofunction, Constant, DirichletBC, Function, FunctionSpace, - TestFunction, assemble) - from firedrake.functionspaceimpl import WithGeometry as FunctionSpaceBase -elif backend == "FEniCS": - from fenics import ( - Constant, DirichletBC, Function, FunctionSpace, TestFunction, assemble) - from fenics import FunctionAssigner, as_backend_type - from dolfin.cpp.function import Constant as cpp_Constant -else: - raise ImportError(f"Unexpected backend: {backend}") +from firedrake import ( + Cofunction, Constant, DirichletBC, Function, FunctionSpace, TestFunction, + assemble) +from firedrake.functionspaceimpl import WithGeometry as FunctionSpaceBase import petsc4py.PETSc as PETSc import ufl @@ -107,7 +94,7 @@ from collections.abc import Sequence from contextlib import contextmanager from enum import Enum -from functools import cached_property, wraps +from functools import wraps import logging __all__ = \ @@ -327,162 +314,100 @@ def split_to_mixed(self, u_fn, u): raise NotImplementedError -if backend == "Firedrake": - def mesh_comm(mesh): - return mesh.comm +def mesh_comm(mesh): + return mesh.comm - class BackendMixedSpace(MixedSpace): - def __init__(self, spaces): - if isinstance(spaces, Sequence): - spaces = tuple(spaces) - else: - spaces = (spaces,) - spaces = tuple_sub(spaces, spaces) - super().__init__( - tuple_sub((space if isinstance(space, FunctionSpaceBase) - else space.dual() for space in iter_sub(spaces)), - spaces)) - self._primal_dual_spaces = tuple(iter_sub(spaces)) - assert len(self._primal_dual_spaces) == len(self._flattened_spaces) - - def new_split(self): - u = [] - for space in self._primal_dual_spaces: - if isinstance(space, FunctionSpaceBase): - u.append(Function(space)) - else: - u.append(Cofunction(space)) - return tuple_sub(u, self._spaces) - - @staticmethod - def _iter_sub_fn(iterable): - def expand(e): - if isinstance(e, (Cofunction, Function)): - space = e.function_space() - if hasattr(space, "num_sub_spaces"): - return tuple(e.sub(i) - for i in range(space.num_sub_spaces())) - return e - - return iter_sub(iterable, expand=expand) - - def mixed_to_split(self, u, u_fn): - if len(self._flattened_spaces) == 1: - u, = tuple(iter_sub(u)) - with u_fn.dat.vec_ro as u_fn_v, u.dat.vec_wo as u_v: - u_fn_v.copy(result=u_v) - else: - for i, u_i in enumerate(self._iter_sub_fn(u)): - with u_fn.sub(i).dat.vec_ro as u_fn_i_v, u_i.dat.vec_wo as u_i_v: # noqa: E501 - u_fn_i_v.copy(result=u_i_v) - - def split_to_mixed(self, u_fn, u): - if len(self._flattened_spaces) == 1: - u, = tuple(iter_sub(u)) - with u_fn.dat.vec_wo as u_fn_v, u.dat.vec_ro as u_v: - u_v.copy(result=u_fn_v) - else: - for i, u_i in enumerate(self._iter_sub_fn(u)): - with u_fn.sub(i).dat.vec_wo as u_fn_i_v, u_i.dat.vec_ro as u_i_v: # noqa: E501 - u_i_v.copy(result=u_fn_i_v) - def mat(a): - return a.petscmat +class BackendMixedSpace(MixedSpace): + def __init__(self, spaces): + if isinstance(spaces, Sequence): + spaces = tuple(spaces) + else: + spaces = (spaces,) + spaces = tuple_sub(spaces, spaces) + super().__init__( + tuple_sub((space if isinstance(space, FunctionSpaceBase) + else space.dual() for space in iter_sub(spaces)), + spaces)) + self._primal_dual_spaces = tuple(iter_sub(spaces)) + assert len(self._primal_dual_spaces) == len(self._flattened_spaces) - @contextmanager - def vec(u, mode=Access.RW): - attribute_name = {Access.RW: "vec", - Access.READ: "vec_ro", - Access.WRITE: "vec_wo"}[mode] - with getattr(u.dat, attribute_name) as u_v: - yield u_v + def new_split(self): + u = [] + for space in self._primal_dual_spaces: + if isinstance(space, FunctionSpaceBase): + u.append(Function(space)) + else: + u.append(Cofunction(space)) + return tuple_sub(u, self._spaces) - def bc_space(bc): - return bc.function_space() + @staticmethod + def _iter_sub_fn(iterable): + def expand(e): + if isinstance(e, (Cofunction, Function)): + space = e.function_space() + if hasattr(space, "num_sub_spaces"): + return tuple(e.sub(i) + for i in range(space.num_sub_spaces())) + return e - def bc_is_homogeneous(bc): - return isinstance(bc.function_arg, ufl.classes.Zero) + return iter_sub(iterable, expand=expand) - def bc_domain_args(bc): - return (bc.sub_domain,) + def mixed_to_split(self, u, u_fn): + if len(self._flattened_spaces) == 1: + u, = tuple(iter_sub(u)) + with u_fn.dat.vec_ro as u_fn_v, u.dat.vec_wo as u_v: + u_fn_v.copy(result=u_v) + else: + for i, u_i in enumerate(self._iter_sub_fn(u)): + with u_fn.sub(i).dat.vec_ro as u_fn_i_v, u_i.dat.vec_wo as u_i_v: # noqa: E501 + u_fn_i_v.copy(result=u_i_v) - def apply_bcs(u, bcs): - if not isinstance(bcs, Sequence): - bcs = (bcs,) - if len(bcs) > 0 and not isinstance(u.function_space(), type(bcs[0].function_space())): # noqa: E501 - u_bc = u.riesz_representation("l2") + def split_to_mixed(self, u_fn, u): + if len(self._flattened_spaces) == 1: + u, = tuple(iter_sub(u)) + with u_fn.dat.vec_wo as u_fn_v, u.dat.vec_ro as u_v: + u_v.copy(result=u_fn_v) else: - u_bc = u - for bc in bcs: - bc.apply(u_bc) -elif backend == "FEniCS": - def mesh_comm(mesh): - return mesh.mpi_comm() - - class BackendMixedSpace(MixedSpace): - @cached_property - def _mixed_to_split_assigner(self): - return FunctionAssigner(list(self._flattened_spaces), - self._mixed_space) - - @cached_property - def _split_to_mixed_assigner(self): - return FunctionAssigner(self._mixed_space, - list(self._flattened_spaces)) - - def mixed_to_split(self, u, u_fn): - if len(self._flattened_spaces) == 1: - u, = tuple(iter_sub(u)) - u.assign(u_fn) - else: - self._mixed_to_split_assigner.assign(list(iter_sub(u)), - u_fn) + for i, u_i in enumerate(self._iter_sub_fn(u)): + with u_fn.sub(i).dat.vec_wo as u_fn_i_v, u_i.dat.vec_ro as u_i_v: # noqa: E501 + u_i_v.copy(result=u_fn_i_v) - def split_to_mixed(self, u_fn, u): - if len(self._flattened_spaces) == 1: - u, = tuple(iter_sub(u)) - u_fn.assign(u) - else: - self._split_to_mixed_assigner.assign(u_fn, - list(iter_sub(u))) - - def mat(a): - matrix = as_backend_type(a).mat() - if not isinstance(matrix, PETSc.Mat): - raise RuntimeError("PETSc backend required") - return matrix - - @contextmanager - def vec(u, mode=Access.RW): - if isinstance(u, Function): - u = u.vector() - u_v = as_backend_type(u).vec() - if not isinstance(u_v, PETSc.Vec): - raise RuntimeError("PETSc backend required") +def mat(a): + return a.petscmat + + +@contextmanager +def vec(u, mode=Access.RW): + attribute_name = {Access.RW: "vec", + Access.READ: "vec_ro", + Access.WRITE: "vec_wo"}[mode] + with getattr(u.dat, attribute_name) as u_v: yield u_v - if mode in {Access.RW, Access.WRITE}: - u.update_ghost_values() - def bc_space(bc): - return FunctionSpace(bc.function_space()) +def bc_space(bc): + return bc.function_space() - def bc_is_homogeneous(bc): - # A weaker check with FEniCS, as the constant might be modified - return (isinstance(bc.value(), cpp_Constant) - and (bc.value().values() == 0.0).all()) - def bc_domain_args(bc): - return (bc.sub_domain, bc.method()) +def bc_is_homogeneous(bc): + return isinstance(bc.function_arg, ufl.classes.Zero) - def apply_bcs(u, bcs): - if not isinstance(bcs, Sequence): - bcs = (bcs,) - for bc in bcs: - bc.apply(u.vector()) -else: - raise ImportError(f"Unexpected backend: {backend}") + +def bc_domain_args(bc): + return (bc.sub_domain,) + + +def apply_bcs(u, bcs): + if not isinstance(bcs, Sequence): + bcs = (bcs,) + if len(bcs) > 0 and not isinstance(u.function_space(), type(bcs[0].function_space())): # noqa: E501 + u_bc = u.riesz_representation("l2") + else: + u_bc = u + for bc in bcs: + bc.apply(u_bc) class Nullspace(ABC): diff --git a/tlm_adjoint/firedrake/caches.py b/tlm_adjoint/firedrake/caches.py index 30339d61..3cfcb598 100644 --- a/tlm_adjoint/firedrake/caches.py +++ b/tlm_adjoint/firedrake/caches.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -"""This module is used by both the FEniCS and Firedrake backends, and -implements finite element assembly and linear solver data caching. +"""This module implements finite element assembly and linear solver data +caching. """ from .backend import TrialFunction, backend_DirichletBC, backend_Function @@ -311,8 +311,7 @@ def assemble(self, form, *, :arg bcs: Dirichlet boundary conditions. :arg form_compiler_parameters: Form compiler parameters. :arg linear_solver_parameters: Linear solver parameters. Required for - assembly parameters which appear in the linear solver parameters - -- in particular the Firedrake `'mat_type'` parameter. + assembly parameters which appear in the linear solver parameters. :arg replace_map: A :class:`Mapping` defining a map from symbolic variables to values. :returns: A :class:`tuple` `(value_ref, value)`, where `value` is the diff --git a/tlm_adjoint/firedrake/equations.py b/tlm_adjoint/firedrake/equations.py index 58e039e5..80d53eab 100644 --- a/tlm_adjoint/firedrake/equations.py +++ b/tlm_adjoint/firedrake/equations.py @@ -1,8 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -"""This module is used by both the FEniCS and Firedrake backends, and -implements finite element calculations. In particular the +"""This module implements finite element calculations. In particular the :class:`EquationSolver` class implements the solution of finite element variational problems. """ @@ -812,8 +811,7 @@ def __init__(self, x, rhs, *args, **kwargs): class DirichletBCApplication(Equation): r"""Represents the application of a Dirichlet boundary condition to a zero - valued backend `Function`. Specifically, with the Firedrake backend this - represents: + valued backend `Function`. Specifically this represents: .. code-block:: python diff --git a/tlm_adjoint/firedrake/functions.py b/tlm_adjoint/firedrake/functions.py index 43cbb66e..4d6ddcaf 100644 --- a/tlm_adjoint/firedrake/functions.py +++ b/tlm_adjoint/firedrake/functions.py @@ -1,9 +1,8 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -"""This module is used by both the FEniCS and Firedrake backends, and includes -functionality for handling :class:`ufl.Coefficient` objects and boundary -conditions. +"""This module includes functionality for handling :class:`ufl.Coefficient` +objects and boundary conditions. """ from .backend import ( @@ -313,10 +312,10 @@ def __init__(self, value=None, *args, name=None, domain=None, space=None, def __new__(cls, value=None, *args, name=None, domain=None, space_type="primal", shape=None, static=False, cache=None, **kwargs): - if issubclass(cls, ufl.classes.Coefficient) or domain is None: + assert not issubclass(cls, ufl.classes.Coefficient) + if domain is None: return object().__new__(cls) else: - # For Firedrake value = constant_value(value, shape) if space_type not in {"primal", "conjugate", "dual", "conjugate_dual"}: @@ -377,8 +376,6 @@ def as_coefficient(x): if isinstance(x, ufl.classes.Coefficient): return x - # For Firedrake - if not isinstance(x, backend_Constant): raise TypeError("Unexpected type") @@ -406,7 +403,6 @@ def with_coefficient(expr, x): if x_coeff is x: return expr, {}, {} else: - # For Firedrake replace_map = {x: x_coeff} replace_map_inverse = {x_coeff: x} return ufl.replace(expr, replace_map), replace_map, replace_map_inverse @@ -417,13 +413,10 @@ def with_coefficients(expr): and "_tlm_adjoint__form_with_coefficients" in expr._cache: return expr._cache["_tlm_adjoint__form_with_coefficients"] - if issubclass(backend_Constant, ufl.classes.Coefficient): - replace_map = {} - else: - # For Firedrake - constants = tuple(sorted(ufl.algorithms.extract_type(expr, backend_Constant), # noqa: E501 - key=lambda c: c.count())) - replace_map = dict(zip(constants, map(as_coefficient, constants))) + assert not issubclass(backend_Constant, ufl.classes.Coefficient) + constants = tuple(sorted(ufl.algorithms.extract_type(expr, backend_Constant), # noqa: E501 + key=lambda c: c.count())) + replace_map = dict(zip(constants, map(as_coefficient, constants))) replace_map_inverse = {c_coeff: c for c, c_coeff in replace_map.items()} @@ -444,11 +437,8 @@ def extract_coefficients(expr): and "_tlm_adjoint__form_coefficients" in expr._cache: return expr._cache["_tlm_adjoint__form_coefficients"] - if issubclass(backend_Constant, ufl.classes.Coefficient): - cls = (ufl.classes.Coefficient,) - else: - # For Firedrake - cls = (ufl.classes.Coefficient, backend_Constant) + assert not issubclass(backend_Constant, ufl.classes.Coefficient) + cls = (ufl.classes.Coefficient, backend_Constant) deps = [] for c in cls: deps.extend(sorted(ufl.algorithms.extract_type(expr, c), From a5f11ff938d3bfff34e32912d6daa087865cfccd Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Mon, 16 Oct 2023 11:22:34 +0100 Subject: [PATCH 3/3] FEniCS backend: Import ufl_legacy, if available --- tests/fenics/test_block_system.py | 5 ++++- tests/fenics/test_caches.py | 5 ++++- tests/fenics/test_equations.py | 5 ++++- tests/fenics/test_gauss_newton.py | 5 ++++- tests/fenics/test_hessian_system.py | 5 ++++- tlm_adjoint/fenics/backend_code_generator_interface.py | 5 ++++- tlm_adjoint/fenics/backend_interface.py | 5 ++++- tlm_adjoint/fenics/backend_overrides.py | 5 ++++- tlm_adjoint/fenics/block_system.py | 5 ++++- tlm_adjoint/fenics/caches.py | 5 ++++- tlm_adjoint/fenics/equations.py | 5 ++++- tlm_adjoint/fenics/fenics_equations.py | 5 ++++- tlm_adjoint/fenics/functions.py | 5 ++++- 13 files changed, 52 insertions(+), 13 deletions(-) diff --git a/tests/fenics/test_block_system.py b/tests/fenics/test_block_system.py index 8c111aec..f60d6a2c 100644 --- a/tests/fenics/test_block_system.py +++ b/tests/fenics/test_block_system.py @@ -6,7 +6,10 @@ import mpi4py.MPI as MPI # noqa: N817 import numpy as np import pytest -import ufl +try: + import ufl_legacy as ufl +except ImportError: + import ufl pytestmark = pytest.mark.skipif( MPI.COMM_WORLD.size not in {1, 4}, diff --git a/tests/fenics/test_caches.py b/tests/fenics/test_caches.py index 2bea5f87..4fcb3e10 100644 --- a/tests/fenics/test_caches.py +++ b/tests/fenics/test_caches.py @@ -10,7 +10,10 @@ import numpy as np import pytest -import ufl +try: + import ufl_legacy as ufl +except ImportError: + import ufl pytestmark = pytest.mark.skipif( DEFAULT_COMM.size not in {1, 4}, diff --git a/tests/fenics/test_equations.py b/tests/fenics/test_equations.py index c97dc285..ad30f3dc 100644 --- a/tests/fenics/test_equations.py +++ b/tests/fenics/test_equations.py @@ -11,7 +11,10 @@ import numpy as np import os import pytest -import ufl +try: + import ufl_legacy as ufl +except ImportError: + import ufl pytestmark = pytest.mark.skipif( DEFAULT_COMM.size not in {1, 4}, diff --git a/tests/fenics/test_gauss_newton.py b/tests/fenics/test_gauss_newton.py index 3aee00a6..88ebdc53 100644 --- a/tests/fenics/test_gauss_newton.py +++ b/tests/fenics/test_gauss_newton.py @@ -8,7 +8,10 @@ import numpy as np import pytest -import ufl +try: + import ufl_legacy as ufl +except ImportError: + import ufl pytestmark = pytest.mark.skipif( DEFAULT_COMM.size not in {1, 4}, diff --git a/tests/fenics/test_hessian_system.py b/tests/fenics/test_hessian_system.py index 96c99c44..cbd1d54b 100644 --- a/tests/fenics/test_hessian_system.py +++ b/tests/fenics/test_hessian_system.py @@ -10,7 +10,10 @@ import numpy as np import pytest -import ufl +try: + import ufl_legacy as ufl +except ImportError: + import ufl pytestmark = pytest.mark.skipif( DEFAULT_COMM.size not in {1, 4}, diff --git a/tlm_adjoint/fenics/backend_code_generator_interface.py b/tlm_adjoint/fenics/backend_code_generator_interface.py index 3378eff4..9d1485c4 100644 --- a/tlm_adjoint/fenics/backend_code_generator_interface.py +++ b/tlm_adjoint/fenics/backend_code_generator_interface.py @@ -20,7 +20,10 @@ from collections.abc import Sequence import ffc import numpy as np -import ufl +try: + import ufl_legacy as ufl +except ImportError: + import ufl __all__ = \ [ diff --git a/tlm_adjoint/fenics/backend_interface.py b/tlm_adjoint/fenics/backend_interface.py index d89002f2..23f91de4 100644 --- a/tlm_adjoint/fenics/backend_interface.py +++ b/tlm_adjoint/fenics/backend_interface.py @@ -27,7 +27,10 @@ import functools import numpy as np -import ufl +try: + import ufl_legacy as ufl +except ImportError: + import ufl __all__ = \ [ diff --git a/tlm_adjoint/fenics/backend_overrides.py b/tlm_adjoint/fenics/backend_overrides.py index 48ddc142..b20d9fbc 100644 --- a/tlm_adjoint/fenics/backend_overrides.py +++ b/tlm_adjoint/fenics/backend_overrides.py @@ -24,7 +24,10 @@ from .functions import Constant, define_var_alias import numpy as np -import ufl +try: + import ufl_legacy as ufl +except ImportError: + import ufl import warnings import weakref diff --git a/tlm_adjoint/fenics/block_system.py b/tlm_adjoint/fenics/block_system.py index 39438610..00d23235 100644 --- a/tlm_adjoint/fenics/block_system.py +++ b/tlm_adjoint/fenics/block_system.py @@ -88,7 +88,10 @@ import petsc4py.PETSc as PETSc -import ufl +try: + import ufl_legacy as ufl +except ImportError: + import ufl from abc import ABC, abstractmethod from collections import deque diff --git a/tlm_adjoint/fenics/caches.py b/tlm_adjoint/fenics/caches.py index 3cfcb598..118b251f 100644 --- a/tlm_adjoint/fenics/caches.py +++ b/tlm_adjoint/fenics/caches.py @@ -18,7 +18,10 @@ replaced_form) from collections import defaultdict -import ufl +try: + import ufl_legacy as ufl +except ImportError: + import ufl __all__ = \ [ diff --git a/tlm_adjoint/fenics/equations.py b/tlm_adjoint/fenics/equations.py index 9ae01457..e8275650 100644 --- a/tlm_adjoint/fenics/equations.py +++ b/tlm_adjoint/fenics/equations.py @@ -30,7 +30,10 @@ derivative, eliminate_zeros, extract_coefficients) import numpy as np -import ufl +try: + import ufl_legacy as ufl +except ImportError: + import ufl __all__ = \ [ diff --git a/tlm_adjoint/fenics/fenics_equations.py b/tlm_adjoint/fenics/fenics_equations.py index f6b091d5..a5c6180c 100644 --- a/tlm_adjoint/fenics/fenics_equations.py +++ b/tlm_adjoint/fenics/fenics_equations.py @@ -26,7 +26,10 @@ import functools import mpi4py.MPI as MPI import numpy as np -import ufl +try: + import ufl_legacy as ufl +except ImportError: + import ufl __all__ = \ [ diff --git a/tlm_adjoint/fenics/functions.py b/tlm_adjoint/fenics/functions.py index cd65c86a..b96ba481 100644 --- a/tlm_adjoint/fenics/functions.py +++ b/tlm_adjoint/fenics/functions.py @@ -20,7 +20,10 @@ from ..overloaded_float import SymbolicFloat import numpy as np -import ufl +try: + import ufl_legacy as ufl +except ImportError: + import ufl import weakref __all__ = \