diff --git a/firedrake/adjoint/all_at_once_reduced_functional.py b/firedrake/adjoint/all_at_once_reduced_functional.py index 42f3f90965..437bd4623f 100644 --- a/firedrake/adjoint/all_at_once_reduced_functional.py +++ b/firedrake/adjoint/all_at_once_reduced_functional.py @@ -5,7 +5,6 @@ from typing import Callable, Optional from contextlib import contextmanager from mpi4py import MPI -from firedrake.petsc import PETSc __all__ = ['AllAtOnceReducedFunctional'] @@ -77,6 +76,20 @@ def _ad_sub(left, right): return result +def _intermediate_options(final_options): + """ + Options set for the intermediate stages of a chain of ReducedFunctionals + + Takes all elements of the final_options except riesz_representation, + which is set to prevent returning derivatives to the primal space. + """ + return { + 'riesz_representation': None, + **{k: v for k, v in final_options.items() + if (k != 'riesz_representation')} + } + + class AllAtOnceReducedFunctional(ReducedFunctional): """ReducedFunctional for 4DVar data assimilation. @@ -359,11 +372,7 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}): # chaining ReducedFunctionals means we need to pass Cofunctions not Functions options = options or {} - intermediate_options = { - 'riesz_representation': 'l2', - **{k: v for k, v in options.items() - if (k != 'riesz_representation')} - } + intermediate_options = _intermediate_options(options) # evaluate first forward model, which contributes to previous chunk sderiv0 = self.stages[0].derivative( @@ -627,7 +636,6 @@ def __next__(self): stage = StrongObservationStage(control, self.aaorf) self._prev_stage = stage - return stage, self.ctx @@ -903,11 +911,7 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}, # chaining ReducedFunctionals means we need to pass Cofunctions not Functions options = options or {} - intermediate_options = { - 'riesz_representation': None, - **{k: v for k, v in options.items() - if (k != 'riesz_representation')} - } + intermediate_options = _intermediate_options(options) if (rftype is None) or (rftype == 'model'): # derivative of reduction @@ -922,8 +926,16 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}, dm_forward = self.forward_model.derivative(adj_input=dm_errors[0], options=options) + sentinel = -12345 + riesz_map = options.get('riesz_representation', sentinel) derivatives.append(dm_forward) - derivatives.append(dm_errors[1].riesz_representation()) + if riesz_map != sentinel: + if riesz_map is None: + derivatives.append(dm_errors[1]) + else: + derivatives.append(dm_errors[1].riesz_representation(riesz_map)) + else: + derivatives.append(dm_errors[1].riesz_representation()) if (rftype is None) or (rftype == 'obs'): # derivative of reduction