Skip to content

Commit

Permalink
Return AdjFloat of an AdjFloat summation (#181)
Browse files Browse the repository at this point in the history
* Return AdjFloat of an AdjFloat summation and check the control update in rf
  • Loading branch information
Ig-dolci authored Nov 28, 2024
1 parent 16e6434 commit 3415139
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 8 deletions.
3 changes: 2 additions & 1 deletion pyadjoint/adjfloat.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,8 @@ def __init__(self, *args):
self.add_dependency(dep)

def recompute_component(self, inputs, block_variable, idx, prepared):
return self.operator(*(term.saved_output for term in self.terms))
output = self.operator(*(term.saved_output for term in self.terms))
return self._outputs[0].saved_output._ad_convert_type(output)

def __str__(self):
return f"{self.terms[0]} {self.symbol} {self.terms[1]}"
Expand Down
16 changes: 16 additions & 0 deletions pyadjoint/reduced_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .enlisting import Enlist
from .tape import get_working_tape, stop_annotating, no_annotations
from .overloaded_type import OverloadedType, create_overloaded_object
from .adjfloat import AdjFloat


def _get_extract_derivative_components(derivative_components):
Expand Down Expand Up @@ -196,6 +197,21 @@ def __call__(self, values):
if len(values) != len(self.controls):
raise ValueError("values should be a list of same length as controls.")

for i, value in enumerate(values):
control_type = type(self.controls[i].control)
if isinstance(value, (int, float)) and control_type is AdjFloat:
value = self.controls[i].control._ad_convert_type(value)
elif not isinstance(value, control_type):
if len(values) == 1:
raise TypeError(
"Control value must be an `OverloadedType` object with the same "
f"type as the control, which is {control_type}"
)
else:
raise TypeError(
f"The control at index {i} must be an `OverloadedType` object "
f"with the same type as the control, which is {control_type}"
)
# Call callback.
self.eval_cb_pre(self.controls.delist(values))

Expand Down
2 changes: 1 addition & 1 deletion pyadjoint/tape.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def continue_annotation():

class set_working_tape(ContextDecorator):
"""Set a new tape as the working tape.
This class can be used in three ways:
1) as a free function to replace the working tape,
2) as a context manager within which a new tape is set as the working tape,
Expand Down
5 changes: 3 additions & 2 deletions tests/firedrake_adjoint/test_burgers_newton.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,15 @@ def J(ic, solve_type, timestep, steps, V):
solver = NonlinearVariationalSolver(problem)

tape = get_working_tape()
J = 0.0
for _ in tape.timestepper(range(steps)):
if solve_type == "NLVS":
solver.solve()
else:
solve(F == 0, u, bc)
u_.assign(u)

return assemble(u_*u_*dx + ic*ic*dx)
J += assemble(u_*u_*dx + ic*ic*dx)
return J


@pytest.mark.parametrize("solve_type, checkpointing",
Expand Down
4 changes: 2 additions & 2 deletions tests/firedrake_adjoint/test_external_modification.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_external_modification():
v1 = Function(fs)
v2 = Function(fs)

u.assign(1.)
u.interpolate(1.)
v1.project(u)
with stop_annotating(modifies=u):
u.dat.data[:] = 2.
Expand All @@ -23,4 +23,4 @@ def test_external_modification():
J = assemble(v1*dx + v2*dx)
Jhat = ReducedFunctional(J, Control(u))

assert np.allclose(J, Jhat(2))
assert np.allclose(J, Jhat(Function(fs).interpolate(2.)))
4 changes: 2 additions & 2 deletions tests/firedrake_adjoint/test_tlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_tlm_bc():
c.block_variable.tlm_value = Function(R, val=1)
tape.evaluate_tlm()

assert (taylor_test(Jhat, Constant(c), Constant(1), dJdm=J.block_variable.tlm_value) > 1.9)
assert (taylor_test(Jhat, c, Function(R, val=1), dJdm=J.block_variable.tlm_value) > 1.9)


def test_tlm_func():
Expand Down Expand Up @@ -234,7 +234,7 @@ def test_projection():

k.block_variable.tlm_value = Constant(1)
tape.evaluate_tlm()
assert(taylor_test(Jhat, Constant(k), Constant(1), dJdm=J.block_variable.tlm_value) > 1.9)
assert(taylor_test(Jhat, k, Function(R, val=1), dJdm=J.block_variable.tlm_value) > 1.9)


def test_projection_function():
Expand Down

0 comments on commit 3415139

Please sign in to comment.