Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the optmisation problem for firedrake and pyrol. #139

Draft
wants to merge 12 commits into
base: master
Choose a base branch
from
23 changes: 17 additions & 6 deletions pyadjoint/optimization/rol_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def scale(self, alpha):
def riesz_map(self, derivs):
dat = []
opts = {"riesz_representation": self.inner_product}
for deriv in Enlist(derivs):
dat.append(deriv._ad_convert_type(deriv, options=opts))
for f, deriv in zip(self.dat, derivs):
dat.append(f._ad_convert_type(deriv, options=opts))
return dat

def dot(self, yy):
Expand All @@ -82,6 +82,15 @@ def dot(self, yy):
res += x._ad_dot(y, options=opts)
return res

def dual(self) -> "ROLVector":
"""Create a new `ROLVector` in the dual space of the current `self`.
"""
dat = []
opts = {"riesz_map": self.inner_product}
for x in self.dat:
dat.append(x._riesz_representation(options=opts))
Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved
return ROLVector(dat, inner_product=self.inner_product)

def norm(self):
return self.dot(self) ** 0.5

Expand Down Expand Up @@ -123,12 +132,14 @@ def applyJacobian(self, jv, v, x, tol):
self.con.jacobian_action(x.dat, v.dat[0], jv.dat)

def applyAdjointJacobian(self, jv, v, x, tol):
self.con.jacobian_adjoint_action(x.dat, v.dat, jv.dat[0])
jv.dat = jv.riesz_map(jv.dat)
tmp = jv.dual()
self.con.jacobian_adjoint_action(x.dat, v.dat, tmp.dat[0])
jv.dat = jv.riesz_map(tmp.dat)

def applyAdjointHessian(self, ahuv, u, v, x, tol):
self.con.hessian_action(x.dat, u.dat[0], v.dat, ahuv.dat[0])
ahuv.dat = ahuv.riesz_map(ahuv.dat)
tmp = ahuv.dual()
self.con.hessian_action(x.dat, u.dat[0], v.dat, tmp.dat[0])
ahuv.dat = ahuv.riesz_map(tmp.dat)

class ROLSolver(OptimizationSolver):
"""
Expand Down
14 changes: 14 additions & 0 deletions pyadjoint/overloaded_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,20 @@ def _ad_convert_type(self, value, options={}):
"""
raise NotImplementedError(f"OverloadedType._ad_convert_type not defined for class {type(self)}.")

def _riesz_representation(self, options={}):
Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved
"""This method must be overridden.

Should implement a way to return the Riesz representation of the overloaded object.

Args:
options (dict): A dictionary with options that may be supplied by the user. If the Riesz representation
functionality offers some options on how to compute it, this is the dictionary that should be used.

Returns:
OverloadedType: The Riesz representation of the overloaded object.
"""
raise NotImplementedError(f"OverloadedType._riesz_representation not defined for class {type(self)}.")

def _ad_create_checkpoint(self):
"""This method must be overridden.

Expand Down