Skip to content

Commit

Permalink
Remove backend dependent code
Browse files Browse the repository at this point in the history
  • Loading branch information
jrmaddison committed Oct 16, 2023
1 parent b84c9f6 commit 1061c50
Show file tree
Hide file tree
Showing 8 changed files with 188 additions and 487 deletions.
250 changes: 77 additions & 173 deletions tlm_adjoint/fenics/block_system.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

r"""This module is used by both the FEniCS and Firedrake backends, and
implements solvers for linear systems defined in mixed spaces.
r"""This module implements solvers for linear systems defined in mixed spaces.
The :class:`System` class defines the block structure of the linear system, and
solves the system using an outer Krylov solver. A custom preconditioner can be
Expand Down Expand Up @@ -82,22 +81,11 @@
with other tlm_adjoint code, space type warnings may be encountered.
"""

# This is the only import from other tlm_adjoint modules that is permitted in
# this module
from .backend import backend

if backend == "Firedrake":
from firedrake import (
Cofunction, Constant, DirichletBC, Function, FunctionSpace,
TestFunction, assemble)
from firedrake.functionspaceimpl import WithGeometry as FunctionSpaceBase
elif backend == "FEniCS":
from fenics import (
Constant, DirichletBC, Function, FunctionSpace, TestFunction, assemble)
from fenics import FunctionAssigner, as_backend_type
from dolfin.cpp.function import Constant as cpp_Constant
else:
raise ImportError(f"Unexpected backend: {backend}")
from fenics import (
Constant, DirichletBC, Function, FunctionSpace, TestFunction, assemble)
from fenics import FunctionAssigner, as_backend_type
from dolfin.cpp.function import Constant as cpp_Constant


import petsc4py.PETSc as PETSc
import ufl
Expand Down Expand Up @@ -327,162 +315,78 @@ def split_to_mixed(self, u_fn, u):
raise NotImplementedError


if backend == "Firedrake":
def mesh_comm(mesh):
return mesh.comm

class BackendMixedSpace(MixedSpace):
def __init__(self, spaces):
if isinstance(spaces, Sequence):
spaces = tuple(spaces)
else:
spaces = (spaces,)
spaces = tuple_sub(spaces, spaces)
super().__init__(
tuple_sub((space if isinstance(space, FunctionSpaceBase)
else space.dual() for space in iter_sub(spaces)),
spaces))
self._primal_dual_spaces = tuple(iter_sub(spaces))
assert len(self._primal_dual_spaces) == len(self._flattened_spaces)

def new_split(self):
u = []
for space in self._primal_dual_spaces:
if isinstance(space, FunctionSpaceBase):
u.append(Function(space))
else:
u.append(Cofunction(space))
return tuple_sub(u, self._spaces)

@staticmethod
def _iter_sub_fn(iterable):
def expand(e):
if isinstance(e, (Cofunction, Function)):
space = e.function_space()
if hasattr(space, "num_sub_spaces"):
return tuple(e.sub(i)
for i in range(space.num_sub_spaces()))
return e

return iter_sub(iterable, expand=expand)

def mixed_to_split(self, u, u_fn):
if len(self._flattened_spaces) == 1:
u, = tuple(iter_sub(u))
with u_fn.dat.vec_ro as u_fn_v, u.dat.vec_wo as u_v:
u_fn_v.copy(result=u_v)
else:
for i, u_i in enumerate(self._iter_sub_fn(u)):
with u_fn.sub(i).dat.vec_ro as u_fn_i_v, u_i.dat.vec_wo as u_i_v: # noqa: E501
u_fn_i_v.copy(result=u_i_v)

def split_to_mixed(self, u_fn, u):
if len(self._flattened_spaces) == 1:
u, = tuple(iter_sub(u))
with u_fn.dat.vec_wo as u_fn_v, u.dat.vec_ro as u_v:
u_v.copy(result=u_fn_v)
else:
for i, u_i in enumerate(self._iter_sub_fn(u)):
with u_fn.sub(i).dat.vec_wo as u_fn_i_v, u_i.dat.vec_ro as u_i_v: # noqa: E501
u_i_v.copy(result=u_fn_i_v)

def mat(a):
return a.petscmat

@contextmanager
def vec(u, mode=Access.RW):
attribute_name = {Access.RW: "vec",
Access.READ: "vec_ro",
Access.WRITE: "vec_wo"}[mode]
with getattr(u.dat, attribute_name) as u_v:
yield u_v

def bc_space(bc):
return bc.function_space()

def bc_is_homogeneous(bc):
return isinstance(bc.function_arg, ufl.classes.Zero)

def bc_domain_args(bc):
return (bc.sub_domain,)

def apply_bcs(u, bcs):
if not isinstance(bcs, Sequence):
bcs = (bcs,)
if len(bcs) > 0 and not isinstance(u.function_space(), type(bcs[0].function_space())): # noqa: E501
u_bc = u.riesz_representation("l2")
def mesh_comm(mesh):
return mesh.mpi_comm()


class BackendMixedSpace(MixedSpace):
@cached_property
def _mixed_to_split_assigner(self):
return FunctionAssigner(list(self._flattened_spaces),
self._mixed_space)

@cached_property
def _split_to_mixed_assigner(self):
return FunctionAssigner(self._mixed_space,
list(self._flattened_spaces))

def mixed_to_split(self, u, u_fn):
if len(self._flattened_spaces) == 1:
u, = tuple(iter_sub(u))
u.assign(u_fn)
else:
u_bc = u
for bc in bcs:
bc.apply(u_bc)
elif backend == "FEniCS":
def mesh_comm(mesh):
return mesh.mpi_comm()

class BackendMixedSpace(MixedSpace):
@cached_property
def _mixed_to_split_assigner(self):
return FunctionAssigner(list(self._flattened_spaces),
self._mixed_space)

@cached_property
def _split_to_mixed_assigner(self):
return FunctionAssigner(self._mixed_space,
list(self._flattened_spaces))

def mixed_to_split(self, u, u_fn):
if len(self._flattened_spaces) == 1:
u, = tuple(iter_sub(u))
u.assign(u_fn)
else:
self._mixed_to_split_assigner.assign(list(iter_sub(u)),
u_fn)

def split_to_mixed(self, u_fn, u):
if len(self._flattened_spaces) == 1:
u, = tuple(iter_sub(u))
u_fn.assign(u)
else:
self._split_to_mixed_assigner.assign(u_fn,
list(iter_sub(u)))

def mat(a):
matrix = as_backend_type(a).mat()
if not isinstance(matrix, PETSc.Mat):
raise RuntimeError("PETSc backend required")
return matrix

@contextmanager
def vec(u, mode=Access.RW):
if isinstance(u, Function):
u = u.vector()
u_v = as_backend_type(u).vec()
if not isinstance(u_v, PETSc.Vec):
raise RuntimeError("PETSc backend required")

yield u_v

if mode in {Access.RW, Access.WRITE}:
u.update_ghost_values()

def bc_space(bc):
return FunctionSpace(bc.function_space())

def bc_is_homogeneous(bc):
# A weaker check with FEniCS, as the constant might be modified
return (isinstance(bc.value(), cpp_Constant)
and (bc.value().values() == 0.0).all())

def bc_domain_args(bc):
return (bc.sub_domain, bc.method())

def apply_bcs(u, bcs):
if not isinstance(bcs, Sequence):
bcs = (bcs,)
for bc in bcs:
bc.apply(u.vector())
else:
raise ImportError(f"Unexpected backend: {backend}")
self._mixed_to_split_assigner.assign(list(iter_sub(u)),
u_fn)

def split_to_mixed(self, u_fn, u):
if len(self._flattened_spaces) == 1:
u, = tuple(iter_sub(u))
u_fn.assign(u)
else:
self._split_to_mixed_assigner.assign(u_fn,
list(iter_sub(u)))


def mat(a):
matrix = as_backend_type(a).mat()
if not isinstance(matrix, PETSc.Mat):
raise RuntimeError("PETSc backend required")
return matrix


@contextmanager
def vec(u, mode=Access.RW):
if isinstance(u, Function):
u = u.vector()
u_v = as_backend_type(u).vec()
if not isinstance(u_v, PETSc.Vec):
raise RuntimeError("PETSc backend required")

yield u_v

if mode in {Access.RW, Access.WRITE}:
u.update_ghost_values()


def bc_space(bc):
return FunctionSpace(bc.function_space())


def bc_is_homogeneous(bc):
# Note that the constant might be modified
return (isinstance(bc.value(), cpp_Constant)
and (bc.value().values() == 0.0).all())


def bc_domain_args(bc):
return (bc.sub_domain, bc.method())


def apply_bcs(u, bcs):
if not isinstance(bcs, Sequence):
bcs = (bcs,)
for bc in bcs:
bc.apply(u.vector())


class Nullspace(ABC):
Expand Down
7 changes: 3 additions & 4 deletions tlm_adjoint/fenics/caches.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""This module is used by both the FEniCS and Firedrake backends, and
implements finite element assembly and linear solver data caching.
"""This module implements finite element assembly and linear solver data
caching.
"""

from .backend import TrialFunction, backend_DirichletBC, backend_Function
Expand Down Expand Up @@ -311,8 +311,7 @@ def assemble(self, form, *,
:arg bcs: Dirichlet boundary conditions.
:arg form_compiler_parameters: Form compiler parameters.
:arg linear_solver_parameters: Linear solver parameters. Required for
assembly parameters which appear in the linear solver parameters
-- in particular the Firedrake `'mat_type'` parameter.
assembly parameters which appear in the linear solver parameters.
:arg replace_map: A :class:`Mapping` defining a map from symbolic
variables to values.
:returns: A :class:`tuple` `(value_ref, value)`, where `value` is the
Expand Down
11 changes: 5 additions & 6 deletions tlm_adjoint/fenics/equations.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""This module is used by both the FEniCS and Firedrake backends, and
implements finite element calculations. In particular the
"""This module implements finite element calculations. In particular the
:class:`EquationSolver` class implements the solution of finite element
variational problems.
"""
Expand Down Expand Up @@ -812,13 +811,13 @@ def __init__(self, x, rhs, *args, **kwargs):

class DirichletBCApplication(Equation):
r"""Represents the application of a Dirichlet boundary condition to a zero
valued backend `Function`. Specifically, with the Firedrake backend this
represents:
valued backend `Function`. Specifically this represents:
.. code-block:: python
x.zero()
DirichletBC(x.function_space(), y, *args, **kwargs).apply(x)
x.vector().zero()
DirichletBC(x.function_space(), y,
*args, **kwargs).apply(x.vector())
The forward residual :math:`\mathcal{F}` is defined so that :math:`\partial
\mathcal{F} / \partial x` is the identity.
Expand Down
Loading

0 comments on commit 1061c50

Please sign in to comment.