From eb30cc7d1805a2b9612430e784ff04d9ec302848 Mon Sep 17 00:00:00 2001 From: Iglesia Dolci Date: Sun, 8 Dec 2024 10:29:59 +0000 Subject: [PATCH] Docs --- .../adjoint/ensemble_reduced_functional.py | 37 +++++++++++++------ 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/firedrake/adjoint/ensemble_reduced_functional.py b/firedrake/adjoint/ensemble_reduced_functional.py index 0ee4fd9b90..38145c5cd8 100644 --- a/firedrake/adjoint/ensemble_reduced_functional.py +++ b/firedrake/adjoint/ensemble_reduced_functional.py @@ -28,7 +28,9 @@ class EnsembleReducedFunctional(ReducedFunctional): operation is employed to sum the functionals and their gradients over an ensemble communicator. - If gather_functional is present, then all the values of J are communicated to all ensemble ranks, and passed in a list to gather_functional, which is a reduced functional that expects a list of that size of the relevant types. + If gather_functional is present, then all the values of J are communicated to all ensemble + ranks, and passed in a list to gather_functional, which is a reduced functional that expects + a list of that size of the relevant types. Parameters ---------- @@ -45,6 +47,24 @@ class EnsembleReducedFunctional(ReducedFunctional): ``Ensemble.ensemble comm``. gather_functional : An instance of the :class:`pyadjoint.ReducedFunctional`. that takes in all of the Js. + derivative_components : list of int + The indices of the controls that the derivative should be computed with respect to. + If present, it overwrites ``derivative_cb_pre`` and ``derivative_cb_post``. + scale : float + A scaling factor applied to the functional and its gradient(with respect to the control). + tape : pyadjoint.Tape + A tape object that the reduced functional will use to evaluate the functional and + its gradient (or gradients). + eval_cb_pre : callable + Callback function before evaluating the functional. Input is a list of Controls. + derivative_cb_pre : callable + Callback function before evaluating derivatives. Input is a list of derivatives. + Should return a list of Controls (usually the same list as the input) to be passed + to :func:`pyadjoint.compute_gradient`. + derivative_cb_post : callable + Callback function after evaluating derivatives. Inputs are the functional, the derivative, + and the controls. All of them are the checkpointed versions. Should return a list of + derivatives (usually the same list as the input)to be returned from ``self.derivative``. See Also @@ -59,22 +79,17 @@ class EnsembleReducedFunctional(ReducedFunctional): `_. """ def __init__(self, J, control, ensemble, scatter_control=True, - gather_functional=None, - derivative_components=None, - scale=1.0, tape=None, - eval_cb_pre=lambda *args: None, + gather_functional=None, derivative_components=None, + scale=1.0, tape=None, eval_cb_pre=lambda *args: None, eval_cb_post=lambda *args: None, derivative_cb_pre=lambda controls: controls, - derivative_cb_post=lambda checkpoint, derivative_components, controls: derivative_components, - hessian_cb_pre=lambda *args: None, - hessian_cb_post=lambda *args: None): + derivative_cb_post=lambda checkpoint, derivative_components, controls: derivative_components + ): super(EnsembleReducedFunctional, self).__init__( J, control, derivative_components=derivative_components, scale=scale, tape=tape, eval_cb_pre=eval_cb_pre, eval_cb_post=eval_cb_post, derivative_cb_pre=derivative_cb_pre, - derivative_cb_post=derivative_cb_post, - hessian_cb_pre=hessian_cb_pre, - hessian_cb_post=hessian_cb_post) + derivative_cb_post=derivative_cb_post) self.ensemble = ensemble self.scatter_control = scatter_control self.gather_functional = gather_functional