From 1061c50d097326ce80415837d96a4e4c094a79c8 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Mon, 16 Oct 2023 09:58:19 +0100 Subject: [PATCH] Remove backend dependent code --- tlm_adjoint/fenics/block_system.py | 250 ++++++++------------------ tlm_adjoint/fenics/caches.py | 7 +- tlm_adjoint/fenics/equations.py | 11 +- tlm_adjoint/fenics/functions.py | 121 +------------ tlm_adjoint/firedrake/block_system.py | 243 +++++++++---------------- tlm_adjoint/firedrake/caches.py | 7 +- tlm_adjoint/firedrake/equations.py | 6 +- tlm_adjoint/firedrake/functions.py | 30 ++-- 8 files changed, 188 insertions(+), 487 deletions(-) diff --git a/tlm_adjoint/fenics/block_system.py b/tlm_adjoint/fenics/block_system.py index 460a6958..39438610 100644 --- a/tlm_adjoint/fenics/block_system.py +++ b/tlm_adjoint/fenics/block_system.py @@ -1,8 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -r"""This module is used by both the FEniCS and Firedrake backends, and -implements solvers for linear systems defined in mixed spaces. +r"""This module implements solvers for linear systems defined in mixed spaces. The :class:`System` class defines the block structure of the linear system, and solves the system using an outer Krylov solver. A custom preconditioner can be @@ -82,22 +81,11 @@ with other tlm_adjoint code, space type warnings may be encountered. """ -# This is the only import from other tlm_adjoint modules that is permitted in -# this module -from .backend import backend - -if backend == "Firedrake": - from firedrake import ( - Cofunction, Constant, DirichletBC, Function, FunctionSpace, - TestFunction, assemble) - from firedrake.functionspaceimpl import WithGeometry as FunctionSpaceBase -elif backend == "FEniCS": - from fenics import ( - Constant, DirichletBC, Function, FunctionSpace, TestFunction, assemble) - from fenics import FunctionAssigner, as_backend_type - from dolfin.cpp.function import Constant as cpp_Constant -else: - raise ImportError(f"Unexpected backend: {backend}") +from fenics import ( + Constant, DirichletBC, Function, FunctionSpace, TestFunction, assemble) +from fenics import FunctionAssigner, as_backend_type +from dolfin.cpp.function import Constant as cpp_Constant + import petsc4py.PETSc as PETSc import ufl @@ -327,162 +315,78 @@ def split_to_mixed(self, u_fn, u): raise NotImplementedError -if backend == "Firedrake": - def mesh_comm(mesh): - return mesh.comm - - class BackendMixedSpace(MixedSpace): - def __init__(self, spaces): - if isinstance(spaces, Sequence): - spaces = tuple(spaces) - else: - spaces = (spaces,) - spaces = tuple_sub(spaces, spaces) - super().__init__( - tuple_sub((space if isinstance(space, FunctionSpaceBase) - else space.dual() for space in iter_sub(spaces)), - spaces)) - self._primal_dual_spaces = tuple(iter_sub(spaces)) - assert len(self._primal_dual_spaces) == len(self._flattened_spaces) - - def new_split(self): - u = [] - for space in self._primal_dual_spaces: - if isinstance(space, FunctionSpaceBase): - u.append(Function(space)) - else: - u.append(Cofunction(space)) - return tuple_sub(u, self._spaces) - - @staticmethod - def _iter_sub_fn(iterable): - def expand(e): - if isinstance(e, (Cofunction, Function)): - space = e.function_space() - if hasattr(space, "num_sub_spaces"): - return tuple(e.sub(i) - for i in range(space.num_sub_spaces())) - return e - - return iter_sub(iterable, expand=expand) - - def mixed_to_split(self, u, u_fn): - if len(self._flattened_spaces) == 1: - u, = tuple(iter_sub(u)) - with u_fn.dat.vec_ro as u_fn_v, u.dat.vec_wo as u_v: - u_fn_v.copy(result=u_v) - else: - for i, u_i in enumerate(self._iter_sub_fn(u)): - with u_fn.sub(i).dat.vec_ro as u_fn_i_v, u_i.dat.vec_wo as u_i_v: # noqa: E501 - u_fn_i_v.copy(result=u_i_v) - - def split_to_mixed(self, u_fn, u): - if len(self._flattened_spaces) == 1: - u, = tuple(iter_sub(u)) - with u_fn.dat.vec_wo as u_fn_v, u.dat.vec_ro as u_v: - u_v.copy(result=u_fn_v) - else: - for i, u_i in enumerate(self._iter_sub_fn(u)): - with u_fn.sub(i).dat.vec_wo as u_fn_i_v, u_i.dat.vec_ro as u_i_v: # noqa: E501 - u_i_v.copy(result=u_fn_i_v) - - def mat(a): - return a.petscmat - - @contextmanager - def vec(u, mode=Access.RW): - attribute_name = {Access.RW: "vec", - Access.READ: "vec_ro", - Access.WRITE: "vec_wo"}[mode] - with getattr(u.dat, attribute_name) as u_v: - yield u_v - - def bc_space(bc): - return bc.function_space() - - def bc_is_homogeneous(bc): - return isinstance(bc.function_arg, ufl.classes.Zero) - - def bc_domain_args(bc): - return (bc.sub_domain,) - - def apply_bcs(u, bcs): - if not isinstance(bcs, Sequence): - bcs = (bcs,) - if len(bcs) > 0 and not isinstance(u.function_space(), type(bcs[0].function_space())): # noqa: E501 - u_bc = u.riesz_representation("l2") +def mesh_comm(mesh): + return mesh.mpi_comm() + + +class BackendMixedSpace(MixedSpace): + @cached_property + def _mixed_to_split_assigner(self): + return FunctionAssigner(list(self._flattened_spaces), + self._mixed_space) + + @cached_property + def _split_to_mixed_assigner(self): + return FunctionAssigner(self._mixed_space, + list(self._flattened_spaces)) + + def mixed_to_split(self, u, u_fn): + if len(self._flattened_spaces) == 1: + u, = tuple(iter_sub(u)) + u.assign(u_fn) else: - u_bc = u - for bc in bcs: - bc.apply(u_bc) -elif backend == "FEniCS": - def mesh_comm(mesh): - return mesh.mpi_comm() - - class BackendMixedSpace(MixedSpace): - @cached_property - def _mixed_to_split_assigner(self): - return FunctionAssigner(list(self._flattened_spaces), - self._mixed_space) - - @cached_property - def _split_to_mixed_assigner(self): - return FunctionAssigner(self._mixed_space, - list(self._flattened_spaces)) - - def mixed_to_split(self, u, u_fn): - if len(self._flattened_spaces) == 1: - u, = tuple(iter_sub(u)) - u.assign(u_fn) - else: - self._mixed_to_split_assigner.assign(list(iter_sub(u)), - u_fn) - - def split_to_mixed(self, u_fn, u): - if len(self._flattened_spaces) == 1: - u, = tuple(iter_sub(u)) - u_fn.assign(u) - else: - self._split_to_mixed_assigner.assign(u_fn, - list(iter_sub(u))) - - def mat(a): - matrix = as_backend_type(a).mat() - if not isinstance(matrix, PETSc.Mat): - raise RuntimeError("PETSc backend required") - return matrix - - @contextmanager - def vec(u, mode=Access.RW): - if isinstance(u, Function): - u = u.vector() - u_v = as_backend_type(u).vec() - if not isinstance(u_v, PETSc.Vec): - raise RuntimeError("PETSc backend required") - - yield u_v - - if mode in {Access.RW, Access.WRITE}: - u.update_ghost_values() - - def bc_space(bc): - return FunctionSpace(bc.function_space()) - - def bc_is_homogeneous(bc): - # A weaker check with FEniCS, as the constant might be modified - return (isinstance(bc.value(), cpp_Constant) - and (bc.value().values() == 0.0).all()) - - def bc_domain_args(bc): - return (bc.sub_domain, bc.method()) - - def apply_bcs(u, bcs): - if not isinstance(bcs, Sequence): - bcs = (bcs,) - for bc in bcs: - bc.apply(u.vector()) -else: - raise ImportError(f"Unexpected backend: {backend}") + self._mixed_to_split_assigner.assign(list(iter_sub(u)), + u_fn) + + def split_to_mixed(self, u_fn, u): + if len(self._flattened_spaces) == 1: + u, = tuple(iter_sub(u)) + u_fn.assign(u) + else: + self._split_to_mixed_assigner.assign(u_fn, + list(iter_sub(u))) + + +def mat(a): + matrix = as_backend_type(a).mat() + if not isinstance(matrix, PETSc.Mat): + raise RuntimeError("PETSc backend required") + return matrix + + +@contextmanager +def vec(u, mode=Access.RW): + if isinstance(u, Function): + u = u.vector() + u_v = as_backend_type(u).vec() + if not isinstance(u_v, PETSc.Vec): + raise RuntimeError("PETSc backend required") + + yield u_v + + if mode in {Access.RW, Access.WRITE}: + u.update_ghost_values() + + +def bc_space(bc): + return FunctionSpace(bc.function_space()) + + +def bc_is_homogeneous(bc): + # Note that the constant might be modified + return (isinstance(bc.value(), cpp_Constant) + and (bc.value().values() == 0.0).all()) + + +def bc_domain_args(bc): + return (bc.sub_domain, bc.method()) + + +def apply_bcs(u, bcs): + if not isinstance(bcs, Sequence): + bcs = (bcs,) + for bc in bcs: + bc.apply(u.vector()) class Nullspace(ABC): diff --git a/tlm_adjoint/fenics/caches.py b/tlm_adjoint/fenics/caches.py index 30339d61..3cfcb598 100644 --- a/tlm_adjoint/fenics/caches.py +++ b/tlm_adjoint/fenics/caches.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -"""This module is used by both the FEniCS and Firedrake backends, and -implements finite element assembly and linear solver data caching. +"""This module implements finite element assembly and linear solver data +caching. """ from .backend import TrialFunction, backend_DirichletBC, backend_Function @@ -311,8 +311,7 @@ def assemble(self, form, *, :arg bcs: Dirichlet boundary conditions. :arg form_compiler_parameters: Form compiler parameters. :arg linear_solver_parameters: Linear solver parameters. Required for - assembly parameters which appear in the linear solver parameters - -- in particular the Firedrake `'mat_type'` parameter. + assembly parameters which appear in the linear solver parameters. :arg replace_map: A :class:`Mapping` defining a map from symbolic variables to values. :returns: A :class:`tuple` `(value_ref, value)`, where `value` is the diff --git a/tlm_adjoint/fenics/equations.py b/tlm_adjoint/fenics/equations.py index 58e039e5..9ae01457 100644 --- a/tlm_adjoint/fenics/equations.py +++ b/tlm_adjoint/fenics/equations.py @@ -1,8 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -"""This module is used by both the FEniCS and Firedrake backends, and -implements finite element calculations. In particular the +"""This module implements finite element calculations. In particular the :class:`EquationSolver` class implements the solution of finite element variational problems. """ @@ -812,13 +811,13 @@ def __init__(self, x, rhs, *args, **kwargs): class DirichletBCApplication(Equation): r"""Represents the application of a Dirichlet boundary condition to a zero - valued backend `Function`. Specifically, with the Firedrake backend this - represents: + valued backend `Function`. Specifically this represents: .. code-block:: python - x.zero() - DirichletBC(x.function_space(), y, *args, **kwargs).apply(x) + x.vector().zero() + DirichletBC(x.function_space(), y, + *args, **kwargs).apply(x.vector()) The forward residual :math:`\mathcal{F}` is defined so that :math:`\partial \mathcal{F} / \partial x` is the identity. diff --git a/tlm_adjoint/fenics/functions.py b/tlm_adjoint/fenics/functions.py index 43cbb66e..cd65c86a 100644 --- a/tlm_adjoint/fenics/functions.py +++ b/tlm_adjoint/fenics/functions.py @@ -1,9 +1,8 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -"""This module is used by both the FEniCS and Firedrake backends, and includes -functionality for handling :class:`ufl.Coefficient` objects and boundary -conditions. +"""This module includes functionality for handling :class:`ufl.Coefficient` +objects and boundary conditions. """ from .backend import ( @@ -310,26 +309,6 @@ def __init__(self, value=None, *args, name=None, domain=None, space=None, self._tlm_adjoint__var_interface_attrs.d_setitem("static", static) self._tlm_adjoint__var_interface_attrs.d_setitem("cache", cache) - def __new__(cls, value=None, *args, name=None, domain=None, - space_type="primal", shape=None, static=False, cache=None, - **kwargs): - if issubclass(cls, ufl.classes.Coefficient) or domain is None: - return object().__new__(cls) - else: - # For Firedrake - value = constant_value(value, shape) - if space_type not in {"primal", "conjugate", - "dual", "conjugate_dual"}: - raise ValueError("Invalid space type") - if cache is None: - cache = static - F = super().__new__(cls, value, domain=domain) - F.rename(name=name) - F._tlm_adjoint__var_interface_attrs.d_setitem("space_type", space_type) # noqa: E501 - F._tlm_adjoint__var_interface_attrs.d_setitem("static", static) - F._tlm_adjoint__var_interface_attrs.d_setitem("cache", cache) - return F - class Zero: """Mixin for defining a zero-valued variable. Used for zero-valued @@ -364,99 +343,17 @@ def __init__(self, *, name=None, domain=None, space=None, if var_linf_norm(self) != 0.0: raise RuntimeError("ZeroConstant is not zero-valued") - def __new__(cls, *args, shape=None, **kwargs): - return Constant.__new__( - cls, constant_value(shape=shape), *args, - shape=shape, static=True, cache=True, **kwargs) - def assign(self, *args, **kwargs): raise RuntimeError("Cannot call assign method of ZeroConstant") -def as_coefficient(x): - if isinstance(x, ufl.classes.Coefficient): - return x - - # For Firedrake - - if not isinstance(x, backend_Constant): - raise TypeError("Unexpected type") - - if not hasattr(x, "_tlm_adjoint__Coefficient"): - if is_var(x): - space = var_space(x) - else: - if len(x.ufl_shape) == 0: - element = ufl.classes.FiniteElement("R", None, 0) - elif len(x.ufl_shape) == 1: - element = ufl.classes.VectorElement("R", None, 0, - dim=x.ufl_shape[0]) - else: - element = ufl.classes.TensorElement("R", None, 0, - shape=x.ufl_shape) - space = ufl.classes.FunctionSpace(None, element) - - x._tlm_adjoint__Coefficient = ufl.classes.Coefficient(space) - - return x._tlm_adjoint__Coefficient - - -def with_coefficient(expr, x): - x_coeff = as_coefficient(x) - if x_coeff is x: - return expr, {}, {} - else: - # For Firedrake - replace_map = {x: x_coeff} - replace_map_inverse = {x_coeff: x} - return ufl.replace(expr, replace_map), replace_map, replace_map_inverse - - -def with_coefficients(expr): - if isinstance(expr, ufl.classes.Form) \ - and "_tlm_adjoint__form_with_coefficients" in expr._cache: - return expr._cache["_tlm_adjoint__form_with_coefficients"] - - if issubclass(backend_Constant, ufl.classes.Coefficient): - replace_map = {} - else: - # For Firedrake - constants = tuple(sorted(ufl.algorithms.extract_type(expr, backend_Constant), # noqa: E501 - key=lambda c: c.count())) - replace_map = dict(zip(constants, map(as_coefficient, constants))) - replace_map_inverse = {c_coeff: c - for c, c_coeff in replace_map.items()} - - expr_with_coeffs = ufl.replace(expr, replace_map) - if isinstance(expr, ufl.classes.Form): - expr._cache["_tlm_adjoint__form_with_coefficients"] = \ - (expr_with_coeffs, replace_map, replace_map_inverse) - return expr_with_coeffs, replace_map, replace_map_inverse - - def extract_coefficients(expr): """ :returns: Variables on which the supplied :class:`ufl.core.expr.Expr` or :class:`ufl.Form` depends. """ - if isinstance(expr, ufl.classes.Form) \ - and "_tlm_adjoint__form_coefficients" in expr._cache: - return expr._cache["_tlm_adjoint__form_coefficients"] - - if issubclass(backend_Constant, ufl.classes.Coefficient): - cls = (ufl.classes.Coefficient,) - else: - # For Firedrake - cls = (ufl.classes.Coefficient, backend_Constant) - deps = [] - for c in cls: - deps.extend(sorted(ufl.algorithms.extract_type(expr, c), - key=lambda dep: dep.count())) - - if isinstance(expr, ufl.classes.Form): - expr._cache["_tlm_adjoint__form_coefficients"] = deps - return deps + return ufl.algorithms.extract_coefficients(expr) def derivative(expr, x, argument=None, *, @@ -474,17 +371,7 @@ def derivative(expr, x, argument=None, *, if isinstance(argument, ufl.classes.Argument) and argument.number() < arity: # noqa: E501 raise ValueError("Invalid argument") - if isinstance(expr, ufl.classes.Expr): - expr, replace_map, replace_map_inverse = with_coefficient(expr, x) - else: - expr, replace_map, replace_map_inverse = with_coefficients(expr) - - if argument is not None: - argument, _, argument_replace_map_inverse = with_coefficients(argument) - replace_map_inverse.update(argument_replace_map_inverse) - - dexpr = ufl.derivative(expr, replace_map.get(x, x), argument=argument) - return ufl.replace(dexpr, replace_map_inverse) + return ufl.derivative(expr, x, argument=argument) def eliminate_zeros(expr, *, force_non_empty_form=False): diff --git a/tlm_adjoint/firedrake/block_system.py b/tlm_adjoint/firedrake/block_system.py index 460a6958..40cb6181 100644 --- a/tlm_adjoint/firedrake/block_system.py +++ b/tlm_adjoint/firedrake/block_system.py @@ -1,8 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -r"""This module is used by both the FEniCS and Firedrake backends, and -implements solvers for linear systems defined in mixed spaces. +r"""This module implements solvers for linear systems defined in mixed spaces. The :class:`System` class defines the block structure of the linear system, and solves the system using an outer Krylov solver. A custom preconditioner can be @@ -82,22 +81,10 @@ with other tlm_adjoint code, space type warnings may be encountered. """ -# This is the only import from other tlm_adjoint modules that is permitted in -# this module -from .backend import backend - -if backend == "Firedrake": - from firedrake import ( - Cofunction, Constant, DirichletBC, Function, FunctionSpace, - TestFunction, assemble) - from firedrake.functionspaceimpl import WithGeometry as FunctionSpaceBase -elif backend == "FEniCS": - from fenics import ( - Constant, DirichletBC, Function, FunctionSpace, TestFunction, assemble) - from fenics import FunctionAssigner, as_backend_type - from dolfin.cpp.function import Constant as cpp_Constant -else: - raise ImportError(f"Unexpected backend: {backend}") +from firedrake import ( + Cofunction, Constant, DirichletBC, Function, FunctionSpace, TestFunction, + assemble) +from firedrake.functionspaceimpl import WithGeometry as FunctionSpaceBase import petsc4py.PETSc as PETSc import ufl @@ -107,7 +94,7 @@ from collections.abc import Sequence from contextlib import contextmanager from enum import Enum -from functools import cached_property, wraps +from functools import wraps import logging __all__ = \ @@ -327,162 +314,100 @@ def split_to_mixed(self, u_fn, u): raise NotImplementedError -if backend == "Firedrake": - def mesh_comm(mesh): - return mesh.comm +def mesh_comm(mesh): + return mesh.comm - class BackendMixedSpace(MixedSpace): - def __init__(self, spaces): - if isinstance(spaces, Sequence): - spaces = tuple(spaces) - else: - spaces = (spaces,) - spaces = tuple_sub(spaces, spaces) - super().__init__( - tuple_sub((space if isinstance(space, FunctionSpaceBase) - else space.dual() for space in iter_sub(spaces)), - spaces)) - self._primal_dual_spaces = tuple(iter_sub(spaces)) - assert len(self._primal_dual_spaces) == len(self._flattened_spaces) - - def new_split(self): - u = [] - for space in self._primal_dual_spaces: - if isinstance(space, FunctionSpaceBase): - u.append(Function(space)) - else: - u.append(Cofunction(space)) - return tuple_sub(u, self._spaces) - - @staticmethod - def _iter_sub_fn(iterable): - def expand(e): - if isinstance(e, (Cofunction, Function)): - space = e.function_space() - if hasattr(space, "num_sub_spaces"): - return tuple(e.sub(i) - for i in range(space.num_sub_spaces())) - return e - - return iter_sub(iterable, expand=expand) - - def mixed_to_split(self, u, u_fn): - if len(self._flattened_spaces) == 1: - u, = tuple(iter_sub(u)) - with u_fn.dat.vec_ro as u_fn_v, u.dat.vec_wo as u_v: - u_fn_v.copy(result=u_v) - else: - for i, u_i in enumerate(self._iter_sub_fn(u)): - with u_fn.sub(i).dat.vec_ro as u_fn_i_v, u_i.dat.vec_wo as u_i_v: # noqa: E501 - u_fn_i_v.copy(result=u_i_v) - - def split_to_mixed(self, u_fn, u): - if len(self._flattened_spaces) == 1: - u, = tuple(iter_sub(u)) - with u_fn.dat.vec_wo as u_fn_v, u.dat.vec_ro as u_v: - u_v.copy(result=u_fn_v) - else: - for i, u_i in enumerate(self._iter_sub_fn(u)): - with u_fn.sub(i).dat.vec_wo as u_fn_i_v, u_i.dat.vec_ro as u_i_v: # noqa: E501 - u_i_v.copy(result=u_fn_i_v) - def mat(a): - return a.petscmat +class BackendMixedSpace(MixedSpace): + def __init__(self, spaces): + if isinstance(spaces, Sequence): + spaces = tuple(spaces) + else: + spaces = (spaces,) + spaces = tuple_sub(spaces, spaces) + super().__init__( + tuple_sub((space if isinstance(space, FunctionSpaceBase) + else space.dual() for space in iter_sub(spaces)), + spaces)) + self._primal_dual_spaces = tuple(iter_sub(spaces)) + assert len(self._primal_dual_spaces) == len(self._flattened_spaces) - @contextmanager - def vec(u, mode=Access.RW): - attribute_name = {Access.RW: "vec", - Access.READ: "vec_ro", - Access.WRITE: "vec_wo"}[mode] - with getattr(u.dat, attribute_name) as u_v: - yield u_v + def new_split(self): + u = [] + for space in self._primal_dual_spaces: + if isinstance(space, FunctionSpaceBase): + u.append(Function(space)) + else: + u.append(Cofunction(space)) + return tuple_sub(u, self._spaces) - def bc_space(bc): - return bc.function_space() + @staticmethod + def _iter_sub_fn(iterable): + def expand(e): + if isinstance(e, (Cofunction, Function)): + space = e.function_space() + if hasattr(space, "num_sub_spaces"): + return tuple(e.sub(i) + for i in range(space.num_sub_spaces())) + return e - def bc_is_homogeneous(bc): - return isinstance(bc.function_arg, ufl.classes.Zero) + return iter_sub(iterable, expand=expand) - def bc_domain_args(bc): - return (bc.sub_domain,) + def mixed_to_split(self, u, u_fn): + if len(self._flattened_spaces) == 1: + u, = tuple(iter_sub(u)) + with u_fn.dat.vec_ro as u_fn_v, u.dat.vec_wo as u_v: + u_fn_v.copy(result=u_v) + else: + for i, u_i in enumerate(self._iter_sub_fn(u)): + with u_fn.sub(i).dat.vec_ro as u_fn_i_v, u_i.dat.vec_wo as u_i_v: # noqa: E501 + u_fn_i_v.copy(result=u_i_v) - def apply_bcs(u, bcs): - if not isinstance(bcs, Sequence): - bcs = (bcs,) - if len(bcs) > 0 and not isinstance(u.function_space(), type(bcs[0].function_space())): # noqa: E501 - u_bc = u.riesz_representation("l2") + def split_to_mixed(self, u_fn, u): + if len(self._flattened_spaces) == 1: + u, = tuple(iter_sub(u)) + with u_fn.dat.vec_wo as u_fn_v, u.dat.vec_ro as u_v: + u_v.copy(result=u_fn_v) else: - u_bc = u - for bc in bcs: - bc.apply(u_bc) -elif backend == "FEniCS": - def mesh_comm(mesh): - return mesh.mpi_comm() - - class BackendMixedSpace(MixedSpace): - @cached_property - def _mixed_to_split_assigner(self): - return FunctionAssigner(list(self._flattened_spaces), - self._mixed_space) - - @cached_property - def _split_to_mixed_assigner(self): - return FunctionAssigner(self._mixed_space, - list(self._flattened_spaces)) - - def mixed_to_split(self, u, u_fn): - if len(self._flattened_spaces) == 1: - u, = tuple(iter_sub(u)) - u.assign(u_fn) - else: - self._mixed_to_split_assigner.assign(list(iter_sub(u)), - u_fn) + for i, u_i in enumerate(self._iter_sub_fn(u)): + with u_fn.sub(i).dat.vec_wo as u_fn_i_v, u_i.dat.vec_ro as u_i_v: # noqa: E501 + u_i_v.copy(result=u_fn_i_v) - def split_to_mixed(self, u_fn, u): - if len(self._flattened_spaces) == 1: - u, = tuple(iter_sub(u)) - u_fn.assign(u) - else: - self._split_to_mixed_assigner.assign(u_fn, - list(iter_sub(u))) - - def mat(a): - matrix = as_backend_type(a).mat() - if not isinstance(matrix, PETSc.Mat): - raise RuntimeError("PETSc backend required") - return matrix - - @contextmanager - def vec(u, mode=Access.RW): - if isinstance(u, Function): - u = u.vector() - u_v = as_backend_type(u).vec() - if not isinstance(u_v, PETSc.Vec): - raise RuntimeError("PETSc backend required") +def mat(a): + return a.petscmat + + +@contextmanager +def vec(u, mode=Access.RW): + attribute_name = {Access.RW: "vec", + Access.READ: "vec_ro", + Access.WRITE: "vec_wo"}[mode] + with getattr(u.dat, attribute_name) as u_v: yield u_v - if mode in {Access.RW, Access.WRITE}: - u.update_ghost_values() - def bc_space(bc): - return FunctionSpace(bc.function_space()) +def bc_space(bc): + return bc.function_space() - def bc_is_homogeneous(bc): - # A weaker check with FEniCS, as the constant might be modified - return (isinstance(bc.value(), cpp_Constant) - and (bc.value().values() == 0.0).all()) - def bc_domain_args(bc): - return (bc.sub_domain, bc.method()) +def bc_is_homogeneous(bc): + return isinstance(bc.function_arg, ufl.classes.Zero) - def apply_bcs(u, bcs): - if not isinstance(bcs, Sequence): - bcs = (bcs,) - for bc in bcs: - bc.apply(u.vector()) -else: - raise ImportError(f"Unexpected backend: {backend}") + +def bc_domain_args(bc): + return (bc.sub_domain,) + + +def apply_bcs(u, bcs): + if not isinstance(bcs, Sequence): + bcs = (bcs,) + if len(bcs) > 0 and not isinstance(u.function_space(), type(bcs[0].function_space())): # noqa: E501 + u_bc = u.riesz_representation("l2") + else: + u_bc = u + for bc in bcs: + bc.apply(u_bc) class Nullspace(ABC): diff --git a/tlm_adjoint/firedrake/caches.py b/tlm_adjoint/firedrake/caches.py index 30339d61..3cfcb598 100644 --- a/tlm_adjoint/firedrake/caches.py +++ b/tlm_adjoint/firedrake/caches.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -"""This module is used by both the FEniCS and Firedrake backends, and -implements finite element assembly and linear solver data caching. +"""This module implements finite element assembly and linear solver data +caching. """ from .backend import TrialFunction, backend_DirichletBC, backend_Function @@ -311,8 +311,7 @@ def assemble(self, form, *, :arg bcs: Dirichlet boundary conditions. :arg form_compiler_parameters: Form compiler parameters. :arg linear_solver_parameters: Linear solver parameters. Required for - assembly parameters which appear in the linear solver parameters - -- in particular the Firedrake `'mat_type'` parameter. + assembly parameters which appear in the linear solver parameters. :arg replace_map: A :class:`Mapping` defining a map from symbolic variables to values. :returns: A :class:`tuple` `(value_ref, value)`, where `value` is the diff --git a/tlm_adjoint/firedrake/equations.py b/tlm_adjoint/firedrake/equations.py index 58e039e5..80d53eab 100644 --- a/tlm_adjoint/firedrake/equations.py +++ b/tlm_adjoint/firedrake/equations.py @@ -1,8 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -"""This module is used by both the FEniCS and Firedrake backends, and -implements finite element calculations. In particular the +"""This module implements finite element calculations. In particular the :class:`EquationSolver` class implements the solution of finite element variational problems. """ @@ -812,8 +811,7 @@ def __init__(self, x, rhs, *args, **kwargs): class DirichletBCApplication(Equation): r"""Represents the application of a Dirichlet boundary condition to a zero - valued backend `Function`. Specifically, with the Firedrake backend this - represents: + valued backend `Function`. Specifically this represents: .. code-block:: python diff --git a/tlm_adjoint/firedrake/functions.py b/tlm_adjoint/firedrake/functions.py index 43cbb66e..4d6ddcaf 100644 --- a/tlm_adjoint/firedrake/functions.py +++ b/tlm_adjoint/firedrake/functions.py @@ -1,9 +1,8 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -"""This module is used by both the FEniCS and Firedrake backends, and includes -functionality for handling :class:`ufl.Coefficient` objects and boundary -conditions. +"""This module includes functionality for handling :class:`ufl.Coefficient` +objects and boundary conditions. """ from .backend import ( @@ -313,10 +312,10 @@ def __init__(self, value=None, *args, name=None, domain=None, space=None, def __new__(cls, value=None, *args, name=None, domain=None, space_type="primal", shape=None, static=False, cache=None, **kwargs): - if issubclass(cls, ufl.classes.Coefficient) or domain is None: + assert not issubclass(cls, ufl.classes.Coefficient) + if domain is None: return object().__new__(cls) else: - # For Firedrake value = constant_value(value, shape) if space_type not in {"primal", "conjugate", "dual", "conjugate_dual"}: @@ -377,8 +376,6 @@ def as_coefficient(x): if isinstance(x, ufl.classes.Coefficient): return x - # For Firedrake - if not isinstance(x, backend_Constant): raise TypeError("Unexpected type") @@ -406,7 +403,6 @@ def with_coefficient(expr, x): if x_coeff is x: return expr, {}, {} else: - # For Firedrake replace_map = {x: x_coeff} replace_map_inverse = {x_coeff: x} return ufl.replace(expr, replace_map), replace_map, replace_map_inverse @@ -417,13 +413,10 @@ def with_coefficients(expr): and "_tlm_adjoint__form_with_coefficients" in expr._cache: return expr._cache["_tlm_adjoint__form_with_coefficients"] - if issubclass(backend_Constant, ufl.classes.Coefficient): - replace_map = {} - else: - # For Firedrake - constants = tuple(sorted(ufl.algorithms.extract_type(expr, backend_Constant), # noqa: E501 - key=lambda c: c.count())) - replace_map = dict(zip(constants, map(as_coefficient, constants))) + assert not issubclass(backend_Constant, ufl.classes.Coefficient) + constants = tuple(sorted(ufl.algorithms.extract_type(expr, backend_Constant), # noqa: E501 + key=lambda c: c.count())) + replace_map = dict(zip(constants, map(as_coefficient, constants))) replace_map_inverse = {c_coeff: c for c, c_coeff in replace_map.items()} @@ -444,11 +437,8 @@ def extract_coefficients(expr): and "_tlm_adjoint__form_coefficients" in expr._cache: return expr._cache["_tlm_adjoint__form_coefficients"] - if issubclass(backend_Constant, ufl.classes.Coefficient): - cls = (ufl.classes.Coefficient,) - else: - # For Firedrake - cls = (ufl.classes.Coefficient, backend_Constant) + assert not issubclass(backend_Constant, ufl.classes.Coefficient) + cls = (ufl.classes.Coefficient, backend_Constant) deps = [] for c in cls: deps.extend(sorted(ufl.algorithms.extract_type(expr, c),