From 3415139a2644ace782918e9f9b15ca821c028c4d Mon Sep 17 00:00:00 2001 From: Daiane Iglesia Dolci <63597005+Ig-dolci@users.noreply.github.com> Date: Thu, 28 Nov 2024 09:43:54 +0000 Subject: [PATCH] Return AdjFloat of an AdjFloat summation (#181) * Return AdjFloat of an AdjFloat summation and check the control update in rf --- pyadjoint/adjfloat.py | 3 ++- pyadjoint/reduced_functional.py | 16 ++++++++++++++++ pyadjoint/tape.py | 2 +- tests/firedrake_adjoint/test_burgers_newton.py | 5 +++-- .../test_external_modification.py | 4 ++-- tests/firedrake_adjoint/test_tlm.py | 4 ++-- 6 files changed, 26 insertions(+), 8 deletions(-) diff --git a/pyadjoint/adjfloat.py b/pyadjoint/adjfloat.py index 790f8feb..d471293d 100644 --- a/pyadjoint/adjfloat.py +++ b/pyadjoint/adjfloat.py @@ -342,7 +342,8 @@ def __init__(self, *args): self.add_dependency(dep) def recompute_component(self, inputs, block_variable, idx, prepared): - return self.operator(*(term.saved_output for term in self.terms)) + output = self.operator(*(term.saved_output for term in self.terms)) + return self._outputs[0].saved_output._ad_convert_type(output) def __str__(self): return f"{self.terms[0]} {self.symbol} {self.terms[1]}" diff --git a/pyadjoint/reduced_functional.py b/pyadjoint/reduced_functional.py index 2ea0aad8..88186fe6 100644 --- a/pyadjoint/reduced_functional.py +++ b/pyadjoint/reduced_functional.py @@ -2,6 +2,7 @@ from .enlisting import Enlist from .tape import get_working_tape, stop_annotating, no_annotations from .overloaded_type import OverloadedType, create_overloaded_object +from .adjfloat import AdjFloat def _get_extract_derivative_components(derivative_components): @@ -196,6 +197,21 @@ def __call__(self, values): if len(values) != len(self.controls): raise ValueError("values should be a list of same length as controls.") + for i, value in enumerate(values): + control_type = type(self.controls[i].control) + if isinstance(value, (int, float)) and control_type is AdjFloat: + value = self.controls[i].control._ad_convert_type(value) + elif not isinstance(value, control_type): + if len(values) == 1: + raise TypeError( + "Control value must be an `OverloadedType` object with the same " + f"type as the control, which is {control_type}" + ) + else: + raise TypeError( + f"The control at index {i} must be an `OverloadedType` object " + f"with the same type as the control, which is {control_type}" + ) # Call callback. self.eval_cb_pre(self.controls.delist(values)) diff --git a/pyadjoint/tape.py b/pyadjoint/tape.py index f2c908fc..b3ac6e04 100644 --- a/pyadjoint/tape.py +++ b/pyadjoint/tape.py @@ -29,7 +29,7 @@ def continue_annotation(): class set_working_tape(ContextDecorator): """Set a new tape as the working tape. - + This class can be used in three ways: 1) as a free function to replace the working tape, 2) as a context manager within which a new tape is set as the working tape, diff --git a/tests/firedrake_adjoint/test_burgers_newton.py b/tests/firedrake_adjoint/test_burgers_newton.py index 0fa12e22..a0e48959 100644 --- a/tests/firedrake_adjoint/test_burgers_newton.py +++ b/tests/firedrake_adjoint/test_burgers_newton.py @@ -95,14 +95,15 @@ def J(ic, solve_type, timestep, steps, V): solver = NonlinearVariationalSolver(problem) tape = get_working_tape() + J = 0.0 for _ in tape.timestepper(range(steps)): if solve_type == "NLVS": solver.solve() else: solve(F == 0, u, bc) u_.assign(u) - - return assemble(u_*u_*dx + ic*ic*dx) + J += assemble(u_*u_*dx + ic*ic*dx) + return J @pytest.mark.parametrize("solve_type, checkpointing", diff --git a/tests/firedrake_adjoint/test_external_modification.py b/tests/firedrake_adjoint/test_external_modification.py index 90b1fb39..58221311 100644 --- a/tests/firedrake_adjoint/test_external_modification.py +++ b/tests/firedrake_adjoint/test_external_modification.py @@ -14,7 +14,7 @@ def test_external_modification(): v1 = Function(fs) v2 = Function(fs) - u.assign(1.) + u.interpolate(1.) v1.project(u) with stop_annotating(modifies=u): u.dat.data[:] = 2. @@ -23,4 +23,4 @@ def test_external_modification(): J = assemble(v1*dx + v2*dx) Jhat = ReducedFunctional(J, Control(u)) - assert np.allclose(J, Jhat(2)) + assert np.allclose(J, Jhat(Function(fs).interpolate(2.))) diff --git a/tests/firedrake_adjoint/test_tlm.py b/tests/firedrake_adjoint/test_tlm.py index e80ecd5f..d5bc832d 100644 --- a/tests/firedrake_adjoint/test_tlm.py +++ b/tests/firedrake_adjoint/test_tlm.py @@ -66,7 +66,7 @@ def test_tlm_bc(): c.block_variable.tlm_value = Function(R, val=1) tape.evaluate_tlm() - assert (taylor_test(Jhat, Constant(c), Constant(1), dJdm=J.block_variable.tlm_value) > 1.9) + assert (taylor_test(Jhat, c, Function(R, val=1), dJdm=J.block_variable.tlm_value) > 1.9) def test_tlm_func(): @@ -234,7 +234,7 @@ def test_projection(): k.block_variable.tlm_value = Constant(1) tape.evaluate_tlm() - assert(taylor_test(Jhat, Constant(k), Constant(1), dJdm=J.block_variable.tlm_value) > 1.9) + assert(taylor_test(Jhat, k, Function(R, val=1), dJdm=J.block_variable.tlm_value) > 1.9) def test_projection_function():