Skip to content

Commit

Permalink
Merge pull request #612 from tlm-adjoint/jrmaddison/remove_unused
Browse files Browse the repository at this point in the history
Remove some unused classes
  • Loading branch information
jrmaddison authored Dec 19, 2024
2 parents 53ea515 + f7d351b commit 0168193
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 319 deletions.
4 changes: 1 addition & 3 deletions tests/firedrake/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,7 @@ def test_leaks():
manager._adj_cache.clear()
for block in list(manager._blocks) + [manager._block]:
for eq in block:
if isinstance(eq, PointInterpolation):
del eq._interp
elif isinstance(eq, AdjointActionMarker):
if isinstance(eq, AdjointActionMarker):
del eq._adj_X

gc.collect()
Expand Down
92 changes: 0 additions & 92 deletions tests/firedrake/test_equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from .test_base import *

import firedrake
import functools
import numpy as np
import os
Expand Down Expand Up @@ -211,97 +210,6 @@ def forward(a, b):
assert min_order > 1.99


@pytest.mark.firedrake
@pytest.mark.parametrize(
"overlap_type", [(firedrake.DistributedMeshOverlapType.NONE, 0),
pytest.param(
(firedrake.DistributedMeshOverlapType.FACET, 1),
marks=pytest.mark.skipif(DEFAULT_COMM.size == 1,
reason="parallel only")),
pytest.param(
(firedrake.DistributedMeshOverlapType.VERTEX, 1),
marks=pytest.mark.skipif(DEFAULT_COMM.size == 1,
reason="parallel only"))])
@pytest.mark.parametrize("N_x, N_y, N_z", [(2, 2, 2),
(5, 5, 5)])
@pytest.mark.parametrize("c", [-1.5, 1.5])
@seed_test
def test_PointInterpolation(setup_test, test_leaks,
overlap_type,
N_x, N_y, N_z,
c):
mesh = UnitCubeMesh(N_x, N_y, N_z,
distribution_parameters={"partition": True,
"overlap_type": overlap_type})
X = SpatialCoordinate(mesh)
y_space = FunctionSpace(mesh, "Lagrange", 3)
X_coords = np.array([[0.1, 0.1, 0.1],
[0.2, 0.3, 0.4],
[0.9, 0.8, 0.7],
[0.4, 0.2, 0.3]], dtype=backend_RealType)

def forward(y):
X_vals = [Constant(name=f"x_{i:d}")
for i in range(X_coords.shape[0])]
eq = PointInterpolation(X_vals, y, X_coords, tolerance=1.0e-14)
eq.solve()

J = Functional(name="J")
for x in X_vals:
term = Constant()
ExprInterpolation(term, x ** 3).solve()
J.addto(term)
return X_vals, J

y = Function(y_space, name="y", static=True)
if complex_mode:
interpolate_expression(y, pow(X[0], 3) - 1.5 * X[0] * X[1] + c
+ 1.0j * pow(X[0], 2))
else:
interpolate_expression(y, pow(X[0], 3) - 1.5 * X[0] * X[1] + c)

start_manager()
X_vals, J = forward(y)
stop_manager()

def x_ref(x):
if complex_mode:
return x[0] ** 3 - 1.5 * x[0] * x[1] + c + 1.0j * x[0] ** 2
else:
return x[0] ** 3 - 1.5 * x[0] * x[1] + c

x_error_norm = 0.0
assert len(X_vals) == len(X_coords)
for x, x_coord in zip(X_vals, X_coords):
x_error_norm = max(x_error_norm,
abs(var_scalar_value(x) - x_ref(x_coord)))
info(f"Error norm = {x_error_norm:.16e}")
assert x_error_norm < 1.0e-13

J_val = J.value

dJ = compute_gradient(J, y)

def forward_J(y):
return forward(y)[1]

min_order = taylor_test(forward_J, y, J_val=J_val, dJ=dJ)
assert min_order > 1.99

ddJ = Hessian(forward_J)
min_order = taylor_test(forward_J, y, J_val=J_val, ddJ=ddJ)
assert min_order > 2.99

min_order = taylor_test_tlm(forward_J, y, tlm_order=1)
assert min_order > 1.99

min_order = taylor_test_tlm_adjoint(forward_J, y, adjoint_order=1)
assert min_order > 1.99

min_order = taylor_test_tlm_adjoint(forward_J, y, adjoint_order=2)
assert min_order > 1.99


@pytest.mark.firedrake
@pytest.mark.parametrize("ExprAssignment_cls", [ExprAssignment,
ExprInterpolation])
Expand Down
98 changes: 4 additions & 94 deletions tlm_adjoint/equations.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .interface import (
Packed, check_space_types, check_space_types_conjugate_dual,
check_space_types_dual, packed, var_assign, var_axpy, var_axpy_conjugate,
var_dot, var_dtype, var_get_values, var_id, var_is_scalar, var_inner,
check_space_types, check_space_types_conjugate_dual,
check_space_types_dual, var_assign, var_axpy, var_axpy_conjugate, var_dot,
var_dtype, var_get_values, var_id, var_is_scalar, var_inner,
var_local_size, var_new_conjugate_dual, var_replacement, var_scalar_value,
var_set_values, var_zero)

Expand All @@ -22,8 +22,7 @@
"DotProductRHS",
"DotProduct",
"InnerProductRHS",
"InnerProduct",
"MatrixActionRHS"
"InnerProduct"
]


Expand Down Expand Up @@ -272,95 +271,6 @@ def __init__(self, x, y, z, *, alpha=1.0, M=None):
super().__init__(x, InnerProductRHS(y, z, alpha=alpha, M=M))


class MatrixActionRHS(RHS):
"""Represents a right-hand-side term
.. math::
A x.
:arg A: A :class:`tlm_adjoint.linear_equation.Matrix` defining :math:`A`.
:arg x: A variable or a :class:`Sequence` of variables defining :math:`x`.
"""

def __init__(self, A, X):
X_packed = Packed(X)
X = tuple(X_packed)
if len(set(map(var_id, X))) != len(X):
raise ValueError("Invalid dependency")

A_nl_deps = A.nonlinear_dependencies()
if len(A_nl_deps) == 0:
x_indices = {i: i for i in range(len(X))}
super().__init__(X, nl_deps=[])
else:
nl_deps = list(A_nl_deps)
nl_dep_ids = {var_id(dep): i for i, dep in enumerate(nl_deps)}
x_indices = {}
for i, x in enumerate(X):
x_id = var_id(x)
if x_id not in nl_dep_ids:
nl_deps.append(x)
nl_dep_ids[x_id] = len(nl_deps) - 1
x_indices[nl_dep_ids[x_id]] = i
super().__init__(nl_deps, nl_deps=nl_deps)

self._packed = X_packed.mapped(lambda x: None)
self._A = A
self._x_indices = x_indices

self.add_referrer(A)

def drop_references(self):
super().drop_references()
self._A = self._A._weak_alias

def _unpack(self, obj):
return self._packed.unpack(obj)

def add_forward(self, B, deps):
B = packed(B)
X = tuple(deps[j] for j in self._x_indices)
self._A.forward_action(deps[:len(self._A.nonlinear_dependencies())],
self._unpack(X),
self._unpack(B),
method="add")

def subtract_adjoint_derivative_action(self, nl_deps, dep_index, adj_X, b):
adj_X = packed(adj_X)
N_A_nl_deps = len(self._A.nonlinear_dependencies())
if dep_index < N_A_nl_deps:
X = tuple(nl_deps[j] for j in self._x_indices)
self._A.adjoint_derivative_action(
nl_deps[:N_A_nl_deps], dep_index,
self._unpack(X),
self._unpack(adj_X),
b, method="sub")

if dep_index in self._x_indices:
self._A.adjoint_action(nl_deps[:N_A_nl_deps],
self._unpack(adj_X),
b, b_index=self._x_indices[dep_index],
method="sub")

def tangent_linear_rhs(self, tlm_map):
deps = self.dependencies()
N_A_nl_deps = len(self._A.nonlinear_dependencies())

X = tuple(deps[j] for j in self._x_indices)
tlm_X = tuple(tlm_map[x] for x in X)
tlm_B = [MatrixActionRHS(self._A, self._unpack(tlm_X))]

if N_A_nl_deps > 0:
tlm_b = self._A.tangent_linear_rhs(tlm_map, X)
if tlm_b is None:
pass
else:
tlm_B.extend(packed(tlm_b))

return tlm_B


class DotProductRHS(RHS):
r"""Represents a right-hand-side term
Expand Down
136 changes: 6 additions & 130 deletions tlm_adjoint/firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,25 @@
"""

from .backend import (
FunctionSpace, Interpolator, TestFunction, VertexOnlyMesh,
backend_Cofunction, backend_Constant, backend_Function)
Interpolator, backend_Cofunction, backend_Constant, backend_Function)
from ..interface import (
check_space_type, check_space_types, comm_dup_cached, packed, space_new,
var_assign, var_assign_conjugate, var_axpy, var_axpy_conjugate, var_comm,
var_copy_conjugate, var_id, var_inner, var_is_scalar, var_new,
var_new_conjugate, var_new_conjugate_dual, var_replacement,
var_scalar_value, var_zero)
check_space_types, var_assign, var_assign_conjugate, var_axpy,
var_axpy_conjugate, var_copy_conjugate, var_id, var_inner, var_new,
var_new_conjugate, var_new_conjugate_dual, var_replacement, var_zero)

from ..equation import Equation, ZeroAssignment
from ..equation import ZeroAssignment
from ..manager import manager_disabled

from .expr import (
ExprEquation, derivative, eliminate_zeros, expr_zero, extract_dependencies,
iter_expr)
from .variables import ReplacementConstant

import itertools
import numpy as np
import ufl

__all__ = \
[
"ExprInterpolation",
"PointInterpolation"
"ExprInterpolation"
]


Expand Down Expand Up @@ -139,121 +133,3 @@ def tangent_linear(self, tlm_map):
return ZeroAssignment(tlm_map[x])
else:
return ExprInterpolation(tlm_map[x], tlm_rhs)


def vmesh_coords_map(vmesh, X_coords):
comm = comm_dup_cached(vmesh.comm)
N, _ = X_coords.shape

vmesh_coords = vmesh.coordinates.dat.data_ro
Nm, _ = vmesh_coords.shape

vmesh_coords_indices = {tuple(vmesh_coords[i, :]): i for i in range(Nm)}
vmesh_coords_map = np.full(Nm, -1, dtype=np.int_)
for i in range(N):
key = tuple(X_coords[i, :])
if key in vmesh_coords_indices:
vmesh_coords_map[vmesh_coords_indices[key]] = i
if (vmesh_coords_map < 0).any():
raise RuntimeError("Failed to find vertex map")

vmesh_coords_map = comm.allgather(vmesh_coords_map)
if len(tuple(itertools.chain(*vmesh_coords_map))) != N:
raise RuntimeError("Failed to find vertex map")

return vmesh_coords_map


class PointInterpolation(Equation):
r"""Represents interpolation of a scalar-valued function at given points.
The forward residual :math:`\mathcal{F}` is defined so that :math:`\partial
\mathcal{F} / \partial x` is the identity.
:arg X: A scalar variable, or a :class:`Sequence` of scalar variables,
defining the forward solution.
:arg y: A scalar-valued :class:`firedrake.function.Function` to
interpolate.
:arg X_coords: A :class:`numpy.ndarray` defining the coordinates at which
to interpolate `y`. Shape is `(n, d)` where `n` is the number of
interpolation points and `d` is the geometric dimension. Ignored if `P`
is supplied.
:arg tolerance: :class:`firedrake.mesh.VertexOnlyMesh` tolerance.
"""

def __init__(self, X, y, X_coords=None, *, tolerance=None,
_interp=None):
X = packed(X)
for x in X:
check_space_type(x, "primal")
if not var_is_scalar(x):
raise ValueError("Solution must be a scalar variable, or a "
"Sequence of scalar variables")
check_space_type(y, "primal")

if X_coords is None:
if _interp is None:
raise TypeError("X_coords required")
else:
if len(X) != X_coords.shape[0]:
raise ValueError("Invalid number of variables")
if not isinstance(y, backend_Function):
raise TypeError("y must be a Function")
if len(y.ufl_shape) > 0:
raise ValueError("y must be a scalar-valued Function")

interp = _interp
if interp is None:
y_space = y.function_space()
vmesh = VertexOnlyMesh(y_space.mesh(), X_coords,
tolerance=tolerance)
vspace = FunctionSpace(vmesh, "Discontinuous Lagrange", 0)
interp = Interpolator(TestFunction(y_space), vspace)
if not hasattr(interp, "_tlm_adjoint__vmesh_coords_map"):
interp._tlm_adjoint__vmesh_coords_map = vmesh_coords_map(vmesh, X_coords) # noqa: E501

super().__init__(X, list(X) + [y], nl_deps=[], ic=False, adj_ic=False)
self._interp = interp

def forward_solve(self, X, deps=None):
y = (self.dependencies() if deps is None else deps)[-1]

Xm = space_new(self._interp.V)
self._interp._interpolate(y, output=Xm)

X_values = var_comm(Xm).allgather(Xm.dat.data_ro)
vmesh_coords_map = self._interp._tlm_adjoint__vmesh_coords_map
for x_val, index in zip(itertools.chain(*X_values),
itertools.chain(*vmesh_coords_map)):
X[index].assign(x_val)

def adjoint_derivative_action(self, nl_deps, dep_index, adj_X):
if dep_index != len(self.X()):
raise ValueError("Unexpected dep_index")

adj_Xm = space_new(self._interp.V.dual())

vmesh_coords_map = self._interp._tlm_adjoint__vmesh_coords_map
rank = var_comm(adj_Xm).rank
# This line must be outside the loop to avoid deadlocks
adj_Xm_data = adj_Xm.dat.data
for i, j in enumerate(vmesh_coords_map[rank]):
adj_Xm_data[i] = var_scalar_value(adj_X[j])

F = var_new_conjugate_dual(self.dependencies()[-1])
self._interp._interpolate(adj_Xm, transpose=True, output=F)
return (-1.0, F)

def adjoint_jacobian_solve(self, adj_X, nl_deps, B):
return B

def tangent_linear(self, tlm_map):
X = self.X()
y = self.dependencies()[-1]

tlm_y = tlm_map[y]
if tlm_y is None:
return ZeroAssignment([tlm_map[x] for x in X])
else:
return PointInterpolation([tlm_map[x] for x in X], tlm_y,
_interp=self._interp)

0 comments on commit 0168193

Please sign in to comment.