Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ReducedFunctional.optimize_tape will optionally _replace_ self.tape with the optimized tape, without modifying original tape. #171

Closed
wants to merge 6 commits into from
1 change: 1 addition & 0 deletions pyadjoint/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def assign_numpy(self, dst, src, offset):
def fetch_numpy(self, value):
return self.control._ad_to_list(value)

# TODO: This should be self.block_variable.checkpoint?
def copy_data(self):
return self.control._ad_copy()

Expand Down
15 changes: 10 additions & 5 deletions pyadjoint/reduced_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ def __init__(self, functional, controls,
if not isinstance(functional, OverloadedType):
raise TypeError("Functional must be an OverloadedType.")
self.functional = functional
self.tape = get_working_tape() if tape is None else tape
self.controls = Enlist(controls)
self.derivative_components = derivative_components
self.scale = scale
Expand All @@ -88,6 +87,13 @@ def __init__(self, functional, controls,
self.hessian_cb_pre = hessian_cb_pre
self.hessian_cb_post = hessian_cb_post

tape = get_working_tape() if tape is None else tape
self.tape = tape.copy()
self.tape.optimize(
controls=self.controls,
functionals=[self.functional]
)

if self.derivative_components:
# pre callback
self.derivative_cb_pre = _get_extract_derivative_components(
Expand Down Expand Up @@ -225,10 +231,9 @@ def __call__(self, values):
return func_value

def optimize_tape(self):
self.tape.optimize(
controls=self.controls,
functionals=[self.functional]
)
# Tape already optimized in __init__
# TODO: What should we do here now?
return self.tape

def marked_controls(self):
return marked_controls(self)
Expand Down
19 changes: 17 additions & 2 deletions pyadjoint/tape.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,10 +368,25 @@ def copy(self):

"""
# TODO: Offer deepcopying. But is it feasible memory wise to copy all checkpoints?
return Tape(

# TODO: firedrake.DiskCheckPointer doesn't implement
# copy so we have to do something else here.
package_data = {}
for k, v in self._package_data.items():
try:
package_data[k] = v.copy()
except NotImplementedError:
package_data[k] = v
tape = Tape(
blocks=self._blocks,
package_data={k: v.copy() for k, v in self._package_data.items()}
package_data=package_data
)
if self._checkpoint_manager is not None:
tape._checkpoint_manager = self._checkpoint_manager
tape.latest_checkpoint = self.latest_checkpoint
if self._bar is not _NullProgressBar:
tape.progress_bar = self.progress_bar
return tape

def checkpoint_block_vars(self, controls=[], tag=None):
"""Returns an object to checkpoint the current state of all block variables on the tape.
Expand Down