From d82cbf004b229fb8220a9b1d2c2b7d2777d9a9a7 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Fri, 11 Oct 2024 23:09:01 -0700 Subject: [PATCH 1/5] Block-Jacobi preconditioning --- src/jaxls/_preconditioning.py | 126 ++++++++++++++++++++++++++++++++++ src/jaxls/_solvers.py | 41 +++++++---- 2 files changed, 155 insertions(+), 12 deletions(-) create mode 100644 src/jaxls/_preconditioning.py diff --git a/src/jaxls/_preconditioning.py b/src/jaxls/_preconditioning.py new file mode 100644 index 0000000..a4742da --- /dev/null +++ b/src/jaxls/_preconditioning.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable + +import jax +from jax import numpy as jnp + +if TYPE_CHECKING: + from ._factor_graph import FactorGraph + from ._sparse_matrices import BlockRowSparseMatrix + + +def make_jacobi_precoditioner( + A_blocksparse: BlockRowSparseMatrix, +) -> Callable[[jax.Array], jax.Array]: + """Returns a diagonal Jacobi preconditioner.""" + ATA_diagonals = jnp.zeros(A_blocksparse.shape[1]) + + for block_row in A_blocksparse.block_rows: + (n_blocks, rows, cols_concat) = block_row.blocks_concat.shape + del rows + del cols_concat + assert block_row.blocks_concat.ndim == 3 # (N_block, rows, cols) + assert block_row.start_cols[0].shape == (n_blocks,) + block_l2_cols = jnp.sum(block_row.blocks_concat**2, axis=1).flatten() + indices = jnp.concatenate( + [ + (start_col[:, None] + jnp.arange(width)[None, :]) + for start_col, width in zip( + block_row.start_cols, block_row.block_widths + ) + ], + axis=1, + ).flatten() + ATA_diagonals = ATA_diagonals.at[indices].add(block_l2_cols) + + return lambda vec: vec / ATA_diagonals + + +def make_block_jacobi_precoditioner( + graph: FactorGraph, A_blocksparse: BlockRowSparseMatrix +) -> Callable[[jax.Array], jax.Array]: + """Returns a Block-Jacobi preconditioner.""" + + # This list will store block diagonal gram matrices corresponding to each + # variable. 
+ gram_diagonal_blocks = list[jax.Array]() + for var_type, ids in graph.tangent_ordering.ordered_dict_items( + graph.sorted_ids_from_var_type + ): + (num_vars,) = ids.shape + gram_diagonal_blocks.append( + jnp.zeros((num_vars, var_type.tangent_dim, var_type.tangent_dim)) + + jnp.eye(var_type.tangent_dim) * 1e-6 + ) + + assert len(graph.stacked_factors) == len(A_blocksparse.block_rows) + for factor, block_row in zip(graph.stacked_factors, A_blocksparse.block_rows): + assert block_row.blocks_concat.ndim == 3 # (N_block, rows, cols) + + # Current index we're looking at in the blocks_concat array. + start_concat_col = 0 + + for var_type, ids in graph.tangent_ordering.ordered_dict_items( + factor.sorted_ids_from_var_type + ): + (num_factors, num_vars) = ids.shape + var_type_idx = graph.tangent_ordering.order_from_type[var_type] + + # Extract the blocks corresponding to the current variable type. + end_concat_col = start_concat_col + num_vars * var_type.tangent_dim + A_blocks = block_row.blocks_concat[ + :, :, start_concat_col:end_concat_col + ].reshape( + ( + num_factors, + factor.residual_dim, + num_vars, + var_type.tangent_dim, + ) + ) + + # f: factor, r: residual, v: variable, t/a: tangent + gram_blocks = jnp.einsum("frvt,frva->fvta", A_blocks, A_blocks) + assert gram_blocks.shape == ( + num_factors, + num_vars, + factor.residual_dim, + factor.residual_dim, + ) + + start_concat_col = end_concat_col + del end_concat_col + + gram_diagonal_blocks[var_type_idx] = ( + gram_diagonal_blocks[var_type_idx] + .at[jnp.searchsorted(graph.sorted_ids_from_var_type[var_type], ids)] + .add(gram_blocks) + ) + + inv_block_diagonals = [ + jnp.linalg.inv(batched_block) for batched_block in gram_diagonal_blocks + ] + + def preconditioner(vec: jax.Array) -> jax.Array: + """Compute Block-Jacobi preconditioning.""" + precond_parts = [] + offset = 0 + for inv_batched_block in inv_block_diagonals: + num_blocks, block_dim, block_dim_ = inv_batched_block.shape + assert block_dim == block_dim_ + 
precond_parts.append( + jnp.einsum( + "bij,bj->bi", + inv_batched_block, + vec[offset : offset + num_blocks * block_dim].reshape( + (num_blocks, block_dim) + ), + ).flatten() + ) + offset += num_blocks * block_dim + out = jnp.concatenate(precond_parts, axis=0) + assert out.shape == vec.shape + return out + + return preconditioner diff --git a/src/jaxls/_solvers.py b/src/jaxls/_solvers.py index 05b4834..b640b25 100644 --- a/src/jaxls/_solvers.py +++ b/src/jaxls/_solvers.py @@ -1,7 +1,7 @@ from __future__ import annotations import functools -from typing import TYPE_CHECKING, Callable, Hashable, cast +from typing import TYPE_CHECKING, Callable, Hashable, Literal, assert_never, cast import jax import jax.experimental.sparse @@ -12,7 +12,12 @@ import sksparse.cholmod from jax import numpy as jnp -from ._sparse_matrices import SparseCooMatrix, SparseCsrMatrix +from jaxls._preconditioning import ( + make_block_jacobi_precoditioner, + make_jacobi_precoditioner, +) + +from ._sparse_matrices import BlockRowSparseMatrix, SparseCsrMatrix from ._variables import VarTypeOrdering, VarValues from .utils import jax_log @@ -79,22 +84,36 @@ class ConjugateGradientLinearSolver: """Iterative solver for sparse linear systems. Can run on CPU or GPU.""" tolerance: float = 1e-7 - inexact_step_eta: float | None = 1e-2 + inexact_step_eta: float | None = None # 1e-2 """Forcing sequence parameter for inexact Newton steps. CG tolerance is set to `eta / iteration #`. For reference, see AN INEXACT LEVENBERG-MARQUARDT METHOD FOR LARGE SPARSE NONLINEAR LEAST SQUARES, Wright & Holt 1983.""" + preconditioner: Literal["block-jacobi", "jacobi"] | None = "block-jacobi" + """Preconditioner to use for linear solves.""" + def _solve( self, + graph: FactorGraph, + A_blocksparse: BlockRowSparseMatrix, ATA_multiply: Callable[[jax.Array], jax.Array], - ATA_diagonals: jax.Array, ATb: jax.Array, iterations: int | jax.Array, ) -> jax.Array: assert len(ATb.shape) == 1, "ATb should be 1D!" 
+ # Preconditioning setup. + if self.preconditioner == "block-jacobi": + preconditioner = make_block_jacobi_precoditioner(graph, A_blocksparse) + elif self.preconditioner == "jacobi": + preconditioner = make_jacobi_precoditioner(A_blocksparse) + elif self.preconditioner is None: + preconditioner = lambda x: x + else: + assert_never(self.preconditioner) + # Solve with conjugate gradient. initial_x = jnp.zeros(ATb.shape) solution_values, _ = jax.scipy.sparse.linalg.cg( @@ -109,7 +128,8 @@ def _solve( ) if self.inexact_step_eta is not None else self.tolerance, - M=lambda x: x / ATA_diagonals, # Jacobi preconditioner. + M=preconditioner, + # M=lambda x: x / ATA_diagonals, # Jacobi preconditioner. ) return solution_values @@ -191,16 +211,13 @@ def step( ATb = -AT_multiply(state.residual_vector) if isinstance(self.linear_solver, ConjugateGradientLinearSolver): - # Get diagonals of ATA for preconditioning. - ATA_diagonals = ( - jnp.zeros_like(ATb).at[graph.jac_coords_coo.cols].add(jac_values**2) - ) local_delta = self.linear_solver._solve( + graph, + A_blocksparse, # We could also use (lambd * ATA_diagonals * vec) for # scale-invariant damping. But this is hard to match with CHOLMOD. lambda vec: AT_multiply(A_multiply(vec)) + state.lambd * vec, - ATA_diagonals, - ATb, + ATb=ATb, iterations=state.iterations, ) elif isinstance(self.linear_solver, CholmodLinearSolver): @@ -308,7 +325,7 @@ class TrustRegionConfig: @jdc.pytree_dataclass class TerminationConfig: # Termination criteria. 
- max_iterations: int = 100 + max_iterations: int = 10 # 100 cost_tolerance: float = 1e-6 """We terminate if `|cost change| / cost < cost_tolerance`.""" gradient_tolerance: float = 1e-7 From 52b9f2fde0bbd9365f5c43c52fe8482006c97176 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Fri, 11 Oct 2024 23:13:21 -0700 Subject: [PATCH 2/5] Nits --- README.md | 12 +++++++----- src/jaxls/_preconditioning.py | 8 ++++---- src/jaxls/_solvers.py | 8 ++++---- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 7264c0d..f912162 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![pyright](https://github.com/brentyi/jaxls/actions/workflows/pyright.yml/badge.svg)](https://github.com/brentyi/jaxls/actions/workflows/pyright.yml) -_status: working! see limitations [here](#limitations)_ +_status: working! see limitations [here](#limitations)_ **`jaxls`** is a library for nonlinear least squares in JAX. @@ -18,11 +18,12 @@ Features: included. - Nonlinear solvers: Levenberg-Marquardt and Gauss-Newton. - Linear solvers: both direct (sparse Cholesky via CHOLMOD, on CPU) and - iterative (Jacobi-preconditioned Conjugate Gradient). + iterative (Conjugate Gradient). +- Preconditioning: block and point Jacobi. -Use cases are primarily in least squares problems that are inherently (1) sparse -and (2) inefficient to solve with gradient-based methods. In robotics, these are -ubiquitous across classical approaches to perception, planning, and control. +Use cases are primarily in least squares problems that are inherently (1) +sparse and (2) inefficient to solve with gradient-based methods. These are +common in robotics. For the first iteration of this library, written for [IROS 2021](https://github.com/brentyi/dfgo), see @@ -122,6 +123,7 @@ print("Pose 1", solution[pose_vars[1]]) ### Limitations There are many practical features that we don't currently support: + - GPU accelerated Cholesky factorization. 
(for CHOLMOD we wrap [scikit-sparse](https://scikit-sparse.readthedocs.io/en/latest/), which runs on CPU only) - Covariance estimation / marginalization. - Incremental solves. diff --git a/src/jaxls/_preconditioning.py b/src/jaxls/_preconditioning.py index a4742da..1edd891 100644 --- a/src/jaxls/_preconditioning.py +++ b/src/jaxls/_preconditioning.py @@ -10,10 +10,10 @@ from ._sparse_matrices import BlockRowSparseMatrix -def make_jacobi_precoditioner( +def make_point_jacobi_precoditioner( A_blocksparse: BlockRowSparseMatrix, ) -> Callable[[jax.Array], jax.Array]: - """Returns a diagonal Jacobi preconditioner.""" + """Returns a point Jacobi (diagonal) preconditioner.""" ATA_diagonals = jnp.zeros(A_blocksparse.shape[1]) for block_row in A_blocksparse.block_rows: @@ -40,7 +40,7 @@ def make_jacobi_precoditioner( def make_block_jacobi_precoditioner( graph: FactorGraph, A_blocksparse: BlockRowSparseMatrix ) -> Callable[[jax.Array], jax.Array]: - """Returns a Block-Jacobi preconditioner.""" + """Returns a block Jacobi preconditioner.""" # This list will store block diagonal gram matrices corresponding to each # variable. 
@@ -103,7 +103,7 @@ def make_block_jacobi_precoditioner( ] def preconditioner(vec: jax.Array) -> jax.Array: - """Compute Block-Jacobi preconditioning.""" + """Compute block Jacobi preconditioning.""" precond_parts = [] offset = 0 for inv_batched_block in inv_block_diagonals: diff --git a/src/jaxls/_solvers.py b/src/jaxls/_solvers.py index b640b25..f068a72 100644 --- a/src/jaxls/_solvers.py +++ b/src/jaxls/_solvers.py @@ -14,7 +14,7 @@ from jaxls._preconditioning import ( make_block_jacobi_precoditioner, - make_jacobi_precoditioner, + make_point_jacobi_precoditioner, ) from ._sparse_matrices import BlockRowSparseMatrix, SparseCsrMatrix @@ -91,7 +91,7 @@ class ConjugateGradientLinearSolver: For reference, see AN INEXACT LEVENBERG-MARQUARDT METHOD FOR LARGE SPARSE NONLINEAR LEAST SQUARES, Wright & Holt 1983.""" - preconditioner: Literal["block-jacobi", "jacobi"] | None = "block-jacobi" + preconditioner: Literal["block-jacobi", "point-jacobi"] | None = "block-jacobi" """Preconditioner to use for linear solves.""" def _solve( @@ -107,8 +107,8 @@ def _solve( # Preconditioning setup. 
if self.preconditioner == "block-jacobi": preconditioner = make_block_jacobi_precoditioner(graph, A_blocksparse) - elif self.preconditioner == "jacobi": - preconditioner = make_jacobi_precoditioner(A_blocksparse) + elif self.preconditioner == "point-jacobi": + preconditioner = make_point_jacobi_precoditioner(A_blocksparse) elif self.preconditioner is None: preconditioner = lambda x: x else: From 7c9ecd1fd3e26b4152ccd9e3e7bfebf46586feb2 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Fri, 11 Oct 2024 23:16:13 -0700 Subject: [PATCH 3/5] Add missing jdc.Static[] --- src/jaxls/_solvers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jaxls/_solvers.py b/src/jaxls/_solvers.py index f068a72..4b12482 100644 --- a/src/jaxls/_solvers.py +++ b/src/jaxls/_solvers.py @@ -91,7 +91,7 @@ class ConjugateGradientLinearSolver: For reference, see AN INEXACT LEVENBERG-MARQUARDT METHOD FOR LARGE SPARSE NONLINEAR LEAST SQUARES, Wright & Holt 1983.""" - preconditioner: Literal["block-jacobi", "point-jacobi"] | None = "block-jacobi" + preconditioner: jdc.Static[Literal["block-jacobi", "point-jacobi"]] | None = "block-jacobi" """Preconditioner to use for linear solves.""" def _solve( From 2c41aefae7811e15be2f7a70ee7e3901c2938b90 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Fri, 11 Oct 2024 23:17:15 -0700 Subject: [PATCH 4/5] Fix --- src/jaxls/_solvers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/jaxls/_solvers.py b/src/jaxls/_solvers.py index 4b12482..b29a952 100644 --- a/src/jaxls/_solvers.py +++ b/src/jaxls/_solvers.py @@ -91,7 +91,9 @@ class ConjugateGradientLinearSolver: For reference, see AN INEXACT LEVENBERG-MARQUARDT METHOD FOR LARGE SPARSE NONLINEAR LEAST SQUARES, Wright & Holt 1983.""" - preconditioner: jdc.Static[Literal["block-jacobi", "point-jacobi"]] | None = "block-jacobi" + preconditioner: jdc.Static[Literal["block-jacobi", "point-jacobi"] | None] = ( + "block-jacobi" + ) """Preconditioner to use for linear solves.""" def 
_solve( From ea8e2da204521d96aca1258a2cb164a91fb709de Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Sat, 12 Oct 2024 00:05:56 -0700 Subject: [PATCH 5/5] Implement Eisenstat-Walker --- README.md | 9 ++--- src/jaxls/_solvers.py | 80 ++++++++++++++++++++++++++++++++----------- 2 files changed, 65 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index f912162..4c5eef3 100644 --- a/README.md +++ b/README.md @@ -11,15 +11,16 @@ problems. We accelerate optimization by analyzing the structure of graphs: repeated factor and variable types are vectorized, and the sparsity of adjacency in the graph is translated into sparse matrix operations. -Features: +Currently supported: - Automatic sparse Jacobians. - Optimization on manifolds; SO(2), SO(3), SE(2), and SE(3) implementations included. - Nonlinear solvers: Levenberg-Marquardt and Gauss-Newton. -- Linear solvers: both direct (sparse Cholesky via CHOLMOD, on CPU) and - iterative (Conjugate Gradient). -- Preconditioning: block and point Jacobi. +- Direct linear solves via sparse Cholesky / CHOLMOD, on CPU. +- Iterative linear solves via Conjugate Gradient. + - Preconditioning: block and point Jacobi. + - Inexact Newton via Eisenstat-Walker. Use cases are primarily in least squares problems that are inherently (1) sparse and (2) inefficient to solve with gradient-based methods. These are diff --git a/src/jaxls/_solvers.py b/src/jaxls/_solvers.py index b29a952..8716531 100644 --- a/src/jaxls/_solvers.py +++ b/src/jaxls/_solvers.py @@ -79,17 +79,36 @@ def _solve_on_host( return factor.solve_A(ATb) +@jdc.pytree_dataclass +class ConjugateGradientState: + """State used for Eisenstat-Walker criterion in ConjugateGradientLinearSolver.""" + + ATb_norm_prev: float | jax.Array + """Previous norm of ATb.""" + eta: float | jax.Array + """Current tolerance.""" + + @jdc.pytree_dataclass class ConjugateGradientLinearSolver: - """Iterative solver for sparse linear systems. 
Can run on CPU or GPU.""" + """Iterative solver for sparse linear systems. Can run on CPU or GPU. - tolerance: float = 1e-7 - inexact_step_eta: float | None = None # 1e-2 - """Forcing sequence parameter for inexact Newton steps. CG tolerance is set to - `eta / iteration #`. + For inexact steps, we use the Eisenstat-Walker criterion. For reference, + see "Choosing the Forcing Terms in an Inexact Newton Method", Eisenstat & + Walker, 1996. + """ - For reference, see AN INEXACT LEVENBERG-MARQUARDT METHOD FOR LARGE SPARSE NONLINEAR - LEAST SQUARES, Wright & Holt 1983.""" + tolerance_min: float = 1e-7 + tolerance_max: float = 1e-2 + + eisenstat_walker_gamma: float = 0.9 + """Eisenstat-Walker criterion gamma term. Multiplies the forcing term used as + the CG tolerance. Typical values range from 0.5 to 0.9. Lower values tighten + the tolerance more aggressively.""" + eisenstat_walker_alpha: float = 2.0 + """Eisenstat-Walker criterion alpha term. Determines the rate at which the + tolerance changes based on residual reduction. Typical values are 1.5 or + 2.0. Higher values make the tolerance more sensitive to residual changes.""" preconditioner: jdc.Static[Literal["block-jacobi", "point-jacobi"] | None] = ( "block-jacobi" @@ -102,8 +121,8 @@ def _solve( A_blocksparse: BlockRowSparseMatrix, ATA_multiply: Callable[[jax.Array], jax.Array], ATb: jax.Array, - iterations: int | jax.Array, - ) -> jax.Array: + prev_linear_state: ConjugateGradientState, + ) -> tuple[jax.Array, ConjugateGradientState]: assert len(ATb.shape) == 1, "ATb should be 1D!" # Preconditioning setup. @@ -116,6 +135,18 @@ def _solve( else: assert_never(self.preconditioner) + # Calculate tolerance using Eisenstat-Walker criterion. 
+ ATb_norm = jnp.linalg.norm(ATb) + current_eta = jnp.minimum( + self.eisenstat_walker_gamma + * (ATb_norm / (prev_linear_state.ATb_norm_prev + 1e-7)) + ** self.eisenstat_walker_alpha, + self.tolerance_max, + ) + current_eta = jnp.maximum( + self.tolerance_min, jnp.minimum(current_eta, prev_linear_state.eta) + ) + # Solve with conjugate gradient. initial_x = jnp.zeros(ATb.shape) solution_values, _ = jax.scipy.sparse.linalg.cg( @@ -124,16 +155,12 @@ def _solve( x0=initial_x, # https://en.wikipedia.org/wiki/Conjugate_gradient_method#Convergence_properties maxiter=len(initial_x), - tol=cast( - float, - jnp.maximum(self.tolerance, self.inexact_step_eta / (iterations + 1)), - ) - if self.inexact_step_eta is not None - else self.tolerance, + tol=cast(float, current_eta), M=preconditioner, - # M=lambda x: x / ATA_diagonals, # Jacobi preconditioner. ) - return solution_values + return solution_values, ConjugateGradientState( + ATb_norm_prev=ATb_norm, eta=current_eta + ) # Nonlinear solvers. @@ -148,6 +175,8 @@ class NonlinearSolverState: done: bool | jax.Array lambd: float | jax.Array + linear_state: ConjugateGradientState | None + @jdc.pytree_dataclass class NonlinearSolver: @@ -171,6 +200,11 @@ def solve(self, graph: FactorGraph, initial_vals: VarValues) -> VarValues: lambd=self.trust_region.lambda_initial if self.trust_region is not None else 0.0, + linear_state=None + if isinstance(self.linear_solver, CholmodLinearSolver) + else ConjugateGradientState( + ATb_norm_prev=0.0, eta=self.linear_solver.tolerance_max + ), ) # Optimization. @@ -212,15 +246,17 @@ def step( # Compute right-hand side of normal equation. 
ATb = -AT_multiply(state.residual_vector) + linear_state = None if isinstance(self.linear_solver, ConjugateGradientLinearSolver): - local_delta = self.linear_solver._solve( + assert isinstance(state.linear_state, ConjugateGradientState) + local_delta, linear_state = self.linear_solver._solve( graph, A_blocksparse, # We could also use (lambd * ATA_diagonals * vec) for # scale-invariant damping. But this is hard to match with CHOLMOD. lambda vec: AT_multiply(A_multiply(vec)) + state.lambd * vec, ATb=ATb, - iterations=state.iterations, + prev_linear_state=state.linear_state, ) elif isinstance(self.linear_solver, CholmodLinearSolver): A_csr = SparseCsrMatrix(jac_values, graph.jac_coords_csr) @@ -258,6 +294,10 @@ def step( proposed_residual_vector = graph.compute_residual_vector(vals) proposed_cost = jnp.sum(proposed_residual_vector**2) + # Update ATb_norm for Eisenstat-Walker criterion. + if linear_state is not None: + state_next.linear_state = linear_state + # Always accept Gauss-Newton steps. if self.trust_region is None: state_next.vals = vals @@ -327,7 +367,7 @@ class TrustRegionConfig: @jdc.pytree_dataclass class TerminationConfig: # Termination criteria. - max_iterations: int = 10 # 100 + max_iterations: int = 100 cost_tolerance: float = 1e-6 """We terminate if `|cost change| / cost < cost_tolerance`.""" gradient_tolerance: float = 1e-7