Skip to content

Commit

Permalink
Merge pull request #591 from tlm-adjoint/jrmaddison/optimization_update
Browse files Browse the repository at this point in the history
Update optimization module
  • Loading branch information
jrmaddison authored Jul 19, 2024
2 parents 818d433 + 33cda51 commit 419414e
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 1,000 deletions.
13 changes: 0 additions & 13 deletions docs/source/acknowledgements.rst
Original file line number Diff line number Diff line change
Expand Up @@ -103,19 +103,6 @@ and in the supporting information for
Journal of Geophysical Research: Oceans, 125(11), e2020JC016370, 2020,
doi: 10.1029/2020JC016370

L-BFGS
``````

The file `tlm_adjoint/optimization.py
<autoapi/tlm_adjoint/optimization/index.html>`_ includes an implementation of
the L-BFGS algorithm, described in

- Jorge Nocedal and Stephen J. Wright, 'Numerical optimization', Springer, New
York, NY, 2006, Second edition, doi: 10.1007/978-0-387-40065-5
- Richard H. Byrd, Peihuang Lu, and Jorge Nocedal, and Ciyou Zhu, 'A limited
memory algorithm for bound constrained optimization', SIAM Journal on
Scientific Computing, 16(5), 1190--1208, 1995, doi: 10.1137/0916069

PyTorch
```````

Expand Down
2 changes: 1 addition & 1 deletion docs/source/examples/8_hessian_uq.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@
"source": [
"optimizer = TAOSolver(lambda m: forward(m)[2], space, H_0_action=B_action,\n",
" solver_parameters={\"tao_type\": \"lmvm\",\n",
" \"tao_gatol\": 1.0e-4,\n",
" \"tao_gatol\": 1.0e-5,\n",
" \"tao_grtol\": 0.0,\n",
" \"tao_gttol\": 0.0,\n",
" \"tao_monitor\": None})\n",
Expand Down
160 changes: 0 additions & 160 deletions tests/fenics/test_minimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,6 @@ def scipy_trust_ncg_minimization(forward, M0):
return M


def l_bfgs_minimization(forward, M0):
M, _ = minimize_l_bfgs(
forward, M0, s_atol=0.0, g_atol=1.0e-11)
return M


def tao_lmvm_minimization(forward, m0):
return minimize_tao(forward, m0,
solver_parameters={"tao_type": "lmvm",
Expand All @@ -51,7 +45,6 @@ def tao_nls_minimization(forward, m0):
@pytest.mark.fenics
@pytest.mark.parametrize("minimize", [scipy_l_bfgs_b_minimization,
scipy_trust_ncg_minimization,
pytest.param(l_bfgs_minimization, marks=pytest.mark.xfail), # noqa: E501
tao_lmvm_minimization,
pytest.param(tao_nls_minimization, marks=pytest.mark.xfail)]) # noqa: E501
@pytest.mark.skipif(complex_mode, reason="real only")
Expand Down Expand Up @@ -96,7 +89,6 @@ def forward_J(alpha):
@pytest.mark.fenics
@pytest.mark.parametrize("minimize", [scipy_l_bfgs_b_minimization,
scipy_trust_ncg_minimization,
pytest.param(l_bfgs_minimization, marks=pytest.mark.xfail), # noqa: E501
tao_lmvm_minimization,
pytest.param(tao_nls_minimization, marks=pytest.mark.xfail)]) # noqa: E501
@pytest.mark.skipif(complex_mode, reason="real only")
Expand Down Expand Up @@ -151,155 +143,3 @@ def forward_J(alpha, beta):
var_assign(error, beta_ref)
var_axpy(error, -1.0, beta)
assert var_linf_norm(error) < 1.0e-8


@pytest.mark.fenics
@pytest.mark.skipif(complex_mode, reason="real only")
@pytest.mark.xfail
@seed_test
def test_l_bfgs_single(setup_test, test_leaks):
mesh = UnitSquareMesh(3, 3)
X = SpatialCoordinate(mesh)
space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
M_l = Function(space, name="M_l", space_type="conjugate_dual")
assemble(test * dx, tensor=M_l)

x_star = Function(space, name="x_star")
interpolate_expression(x_star, sin(pi * X[0]) * sin(2.0 * pi * X[1]))

def F(x):
check_space_type(x, "primal")
return assemble(0.5 * inner(x - x_star, x - x_star) * dx)

def Fp(x):
check_space_type(x, "primal")
Fp = Function(space, name="Fp", space_type="conjugate_dual")
assemble(inner(x - x_star, test) * dx, tensor=Fp)
return Fp

def H_0_action(x):
check_space_type(x, "conjugate_dual")
H_0_action = Function(space, name="H_0_action")
var_set_values(H_0_action,
var_get_values(x)
/ var_get_values(M_l))
return H_0_action

def B_0_action(x):
check_space_type(x, "primal")
B_0_action = Function(space, name="B_0_action",
space_type="conjugate_dual")
var_set_values(B_0_action,
var_get_values(x)
* var_get_values(M_l))
return B_0_action

x0 = Function(space, name="x0")
x, (its, F_calls, Fp_calls, _) = l_bfgs(
F, Fp, x0, m=30, s_atol=0.0, g_atol=1.0e-12,
H_0_action=H_0_action, M_action=B_0_action, M_inv_action=H_0_action)

error = var_copy(x, name="error")
var_axpy(error, -1.0, x_star)
error_norm = var_linf_norm(error)
info(f"{error_norm=:.6e}")
info(f"{F_calls=:d}")
info(f"{Fp_calls=:d}")

assert abs(F(x)) < 1.0e-25
assert error_norm < 1.0e-12
assert its == 15
assert F_calls == 17
assert Fp_calls == 17


@pytest.mark.fenics
@pytest.mark.skipif(complex_mode, reason="real only")
@pytest.mark.xfail
@seed_test
def test_l_bfgs_multiple(setup_test, test_leaks):
mesh = UnitSquareMesh(3, 3)
X = SpatialCoordinate(mesh)
space_x = FunctionSpace(mesh, "Lagrange", 1)
space_y = FunctionSpace(mesh, "Discontinuous Lagrange", 1)
test_x = TestFunction(space_x)
test_y = TestFunction(space_y)
M_l_x = Function(space_x, name="M_l_x", space_type="conjugate_dual")
M_l_y = Function(space_y, name="M_l_y", space_type="conjugate_dual")
assemble(test_x * dx, tensor=M_l_x)
assemble(test_y * dx, tensor=M_l_y)

x_star = Function(space_x, name="x_star")
interpolate_expression(x_star, sin(pi * X[0]) * sin(2.0 * pi * X[1]))
y_star = Function(space_y, name="y_star")
interpolate_expression(y_star, exp(X[0]) * exp(X[1]))
alpha_y = (1.0 + X[0]) * (1.0 + X[0])

def F(x, y):
check_space_type(x, "primal")
check_space_type(y, "primal")
return assemble(0.5 * inner(x - x_star, x - x_star) * dx
+ 0.5 * inner(y - y_star, alpha_y * (y - y_star)) * dx)

def Fp(x, y):
check_space_type(x, "primal")
check_space_type(y, "primal")
Fp = (Function(space_x, name="Fp_0", space_type="conjugate_dual"),
Function(space_y, name="Fp_1", space_type="conjugate_dual"))
assemble(inner(x - x_star, test_x) * dx, tensor=Fp[0])
assemble(inner(alpha_y * (y - y_star), test_y) * dx, tensor=Fp[1])
return Fp

def H_0_action(x, y):
check_space_type(x, "conjugate_dual")
check_space_type(y, "conjugate_dual")
H_0_action = (Function(space_x, name="H_0_action_0"),
Function(space_y, name="H_0_action_1"))
var_set_values(H_0_action[0],
var_get_values(x)
/ var_get_values(M_l_x))
var_set_values(H_0_action[1],
var_get_values(y)
/ var_get_values(M_l_y))
return H_0_action

def B_0_action(x, y):
check_space_type(x, "primal")
check_space_type(y, "primal")
B_0_action = (Function(space_x, name="B_0_action_0",
space_type="conjugate_dual"),
Function(space_y, name="B_0_action_1",
space_type="conjugate_dual"))
var_set_values(B_0_action[0],
var_get_values(x)
* var_get_values(M_l_x))
var_set_values(B_0_action[1],
var_get_values(y)
* var_get_values(M_l_y))
return B_0_action

x0 = Function(space_x, name="x0")
y0 = Function(space_y, name="y0")
(x, y), (its, F_calls, Fp_calls, _) = \
l_bfgs(F, Fp, (x0, y0), m=30, s_atol=0.0, g_atol=1.0e-12,
H_0_action=H_0_action,
M_action=B_0_action, M_inv_action=H_0_action)

x_error = var_copy(x, name="x_error")
var_axpy(x_error, -1.0, x_star)
x_error_norm = var_linf_norm(x_error)
y_error = var_copy(y, name="y_error")
var_axpy(y_error, -1.0, y_star)
y_error_norm = var_linf_norm(y_error)
info(f"{x_error_norm=:.6e}")
info(f"{y_error_norm=:.6e}")
info(f"{F_calls=:d}")
info(f"{Fp_calls=:d}")

assert abs(F(x, y)) < 1.0e-23
assert x_error_norm < 1.0e-11
assert y_error_norm < 1.0e-11
assert its <= 38
assert F_calls <= 42
assert Fp_calls <= 42
9 changes: 6 additions & 3 deletions tests/firedrake/test_hessian_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,13 @@ def forward_J(m):
sin(2.0 * pi * X[0]) * sin(3.0 * pi * X[1]) * exp(4.0 * X[0] * X[1]))

m0 = Function(space, name="m0")
m, _ = minimize_l_bfgs(
m = minimize_tao(
forward_J, m0,
s_atol=0.0, g_atol=1.0e-7,
H_0_action=B, M_action=B_inv, M_inv_action=B)
solver_parameters={"tao_type": "lmvm",
"tao_gatol": 1.0e-7,
"tao_grtol": 0.0,
"tao_gttol": 0.0},
H_0_action=B)

b_ref = Cofunction(space.dual(), name="b_ref")
assemble(inner((sin(5.0 * pi * X[0]) * sin(7.0 * pi * X[1])) ** 2, test) * dx, # noqa: E501
Expand Down
Loading

0 comments on commit 419414e

Please sign in to comment.