Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add all reduced functional arguments. #3908

Merged
merged 6 commits into from
Dec 11, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 43 additions & 4 deletions firedrake/adjoint/ensemble_reduced_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ class EnsembleReducedFunctional(ReducedFunctional):
operation is employed to sum the functionals and their gradients over an ensemble
communicator.

If gather_functional is present, then all the values of J are communicated to all ensemble ranks, and passed in a list to gather_functional, which is a reduced functional that expects a list of that size of the relevant types.
If gather_functional is present, then all the values of J are communicated to all ensemble
ranks, and passed in a list to gather_functional, which is a reduced functional that expects
a list of that size of the relevant types.

Parameters
----------
Expand All @@ -45,7 +47,33 @@ class EnsembleReducedFunctional(ReducedFunctional):
``Ensemble.ensemble comm``.
gather_functional : An instance of the :class:`pyadjoint.ReducedFunctional`.
that takes in all of the Js.

derivative_components : list of int
The indices of the controls that the derivative should be computed with respect to.
If present, it overwrites ``derivative_cb_pre`` and ``derivative_cb_post``.
scale : float
A scaling factor applied to the functional and its gradient(with respect to the control).
Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved
tape : pyadjoint.Tape
A tape object that the reduced functional will use to evaluate the functional and
its gradients (or derivatives).
eval_cb_pre : :func:
Copy link
Contributor Author

@Ig-dolci Ig-dolci Dec 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check if eval_cb_pos is in the code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Callback function before evaluating the functional. Input is a list of Controls.
eval_cb_pos : :func:
Callback function after evaluating the functional. Inputs are the functional value
and a list of Controls.
derivative_cb_pre : :func:
Callback function before evaluating gradients (or derivatives). Input is a list of
gradients (or derivatives). Should return a list of Controls (usually the same list as
the input) to be passed to :func:`pyadjoint.compute_gradient`.
derivative_cb_post : :func:
Callback function after evaluating derivatives. Inputs are the functional, a list of
gradients (or derivatives), and controls. All of them are the checkpointed versions.
Should return a list of gradients (or derivatives) (usually the same list as the input)
to be returned from ``self.derivative``.
hessian_cb_pre : :func:
Callback function before evaluating the Hessian. Input is a list of Controls.
hessian_cb_post : :func:
Callback function after evaluating the Hessian. Inputs are the functional, a list of
Hessian, and controls.

See Also
--------
Expand All @@ -59,8 +87,19 @@ class EnsembleReducedFunctional(ReducedFunctional):
<https://www.firedrakeproject.org/parallelism.html#id8>`_.
"""
def __init__(self, J, control, ensemble, scatter_control=True,
gather_functional=None):
super(EnsembleReducedFunctional, self).__init__(J, control)
gather_functional=None, derivative_components=None,
scale=1.0, tape=None, eval_cb_pre=lambda *args: None,
eval_cb_post=lambda *args: None,
derivative_cb_pre=lambda controls: controls,
derivative_cb_post=lambda checkpoint, derivative_components, controls: derivative_components,
hessian_cb_pre=lambda *args: None, hessian_cb_post=lambda *args: None):
super(EnsembleReducedFunctional, self).__init__(
J, control, derivative_components=derivative_components,
scale=scale, tape=tape, eval_cb_pre=eval_cb_pre,
eval_cb_post=eval_cb_post, derivative_cb_pre=derivative_cb_pre,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add hessian args

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

derivative_cb_post=derivative_cb_post,
hessian_cb_pre=hessian_cb_pre, hessian_cb_post=hessian_cb_post)

self.ensemble = ensemble
self.scatter_control = scatter_control
self.gather_functional = gather_functional
Expand Down
Loading