From 8576a1e3c3d452a0fdfb8a8bc0b90611872cc303 Mon Sep 17 00:00:00 2001
From: Brent Yi
Date: Sat, 12 Oct 2024 00:19:26 -0700
Subject: [PATCH] Block-Jacobi preconditioning, Eisenstat-Walker for inexact
 steps (#19)

* Block-Jacobi preconditioning

* Nits

* Add missing jdc.Static[]

* Fix

* Implement Eisenstat-Walker
---
 README.md                     |  17 +++--
 src/jaxls/_preconditioning.py | 126 ++++++++++++++++++++++++++++++++++
 src/jaxls/_solvers.py         | 117 +++++++++++++++++++++++--------
 3 files changed, 224 insertions(+), 36 deletions(-)
 create mode 100644 src/jaxls/_preconditioning.py

diff --git a/README.md b/README.md
index 7264c0d..4c5eef3 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
 
 [![pyright](https://github.com/brentyi/jaxls/actions/workflows/pyright.yml/badge.svg)](https://github.com/brentyi/jaxls/actions/workflows/pyright.yml)
 
-_status: working! see limitations [here](#limitations)_
+_status: working! see limitations [here](#limitations)_
 
 **`jaxls`** is a library for nonlinear least squares in JAX.
 
@@ -11,18 +11,20 @@
 problems. We accelerate optimization by analyzing the structure of graphs:
 repeated factor and variable types are vectorized, and the sparsity of
 adjacency in the graph is translated into sparse matrix operations.
 
-Features:
+Currently supported:
 - Automatic sparse Jacobians.
 - Optimization on manifolds; SO(2), SO(3), SE(2), and SE(3) implementations
   included.
 - Nonlinear solvers: Levenberg-Marquardt and Gauss-Newton.
-- Linear solvers: both direct (sparse Cholesky via CHOLMOD, on CPU) and
-  iterative (Jacobi-preconditioned Conjugate Gradient).
+- Direct linear solves via sparse Cholesky / CHOLMOD, on CPU.
+- Iterative linear solves via Conjugate Gradient.
+  - Preconditioning: block and point Jacobi.
+  - Inexact Newton via Eisenstat-Walker.
 
-Use cases are primarily in least squares problems that are inherently (1) sparse
-and (2) inefficient to solve with gradient-based methods. In robotics, these are
-ubiquitous across classical approaches to perception, planning, and control.
+Use cases are primarily in least squares problems that are inherently (1)
+sparse and (2) inefficient to solve with gradient-based methods. These are
+common in robotics.
 
 For the first iteration of this library, written for
 [IROS 2021](https://github.com/brentyi/dfgo), see
@@ -122,6 +124,7 @@ print("Pose 1", solution[pose_vars[1]])
 ### Limitations
 
 There are many practical features that we don't currently support:
+
 - GPU accelerated Cholesky factorization. (for CHOLMOD we wrap [scikit-sparse](https://scikit-sparse.readthedocs.io/en/latest/), which runs on CPU only)
 - Covariance estimation / marginalization.
 - Incremental solves.
diff --git a/src/jaxls/_preconditioning.py b/src/jaxls/_preconditioning.py
new file mode 100644
index 0000000..1edd891
--- /dev/null
+++ b/src/jaxls/_preconditioning.py
@@ -0,0 +1,126 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Callable
+
+import jax
+from jax import numpy as jnp
+
+if TYPE_CHECKING:
+    from ._factor_graph import FactorGraph
+    from ._sparse_matrices import BlockRowSparseMatrix
+
+
+def make_point_jacobi_precoditioner(
+    A_blocksparse: BlockRowSparseMatrix,
+) -> Callable[[jax.Array], jax.Array]:
+    """Returns a point Jacobi (diagonal) preconditioner."""
+    ATA_diagonals = jnp.zeros(A_blocksparse.shape[1])
+
+    for block_row in A_blocksparse.block_rows:
+        (n_blocks, rows, cols_concat) = block_row.blocks_concat.shape
+        del rows
+        del cols_concat
+        assert block_row.blocks_concat.ndim == 3  # (N_block, rows, cols)
+        assert block_row.start_cols[0].shape == (n_blocks,)
+        block_l2_cols = jnp.sum(block_row.blocks_concat**2, axis=1).flatten()
+        indices = jnp.concatenate(
+            [
+                (start_col[:, None] + jnp.arange(width)[None, :])
+                for start_col, width in zip(
+                    block_row.start_cols, block_row.block_widths
+                )
+            ],
+            axis=1,
+        ).flatten()
+        ATA_diagonals = ATA_diagonals.at[indices].add(block_l2_cols)
+
+    return lambda vec: vec / ATA_diagonals
+
+
+def make_block_jacobi_precoditioner(
+    graph: FactorGraph, A_blocksparse: BlockRowSparseMatrix
+) -> Callable[[jax.Array], jax.Array]:
+    """Returns a block Jacobi preconditioner."""
+
+    # This list will store block diagonal gram matrices corresponding to each
+    # variable.
+    gram_diagonal_blocks = list[jax.Array]()
+    for var_type, ids in graph.tangent_ordering.ordered_dict_items(
+        graph.sorted_ids_from_var_type
+    ):
+        (num_vars,) = ids.shape
+        gram_diagonal_blocks.append(
+            jnp.zeros((num_vars, var_type.tangent_dim, var_type.tangent_dim))
+            + jnp.eye(var_type.tangent_dim) * 1e-6
+        )
+
+    assert len(graph.stacked_factors) == len(A_blocksparse.block_rows)
+    for factor, block_row in zip(graph.stacked_factors, A_blocksparse.block_rows):
+        assert block_row.blocks_concat.ndim == 3  # (N_block, rows, cols)
+
+        # Current index we're looking at in the blocks_concat array.
+        start_concat_col = 0
+
+        for var_type, ids in graph.tangent_ordering.ordered_dict_items(
+            factor.sorted_ids_from_var_type
+        ):
+            (num_factors, num_vars) = ids.shape
+            var_type_idx = graph.tangent_ordering.order_from_type[var_type]
+
+            # Extract the blocks corresponding to the current variable type.
+            end_concat_col = start_concat_col + num_vars * var_type.tangent_dim
+            A_blocks = block_row.blocks_concat[
+                :, :, start_concat_col:end_concat_col
+            ].reshape(
+                (
+                    num_factors,
+                    factor.residual_dim,
+                    num_vars,
+                    var_type.tangent_dim,
+                )
+            )
+
+            # f: factor, r: residual, v: variable, t/a: tangent
+            gram_blocks = jnp.einsum("frvt,frva->fvta", A_blocks, A_blocks)
+            assert gram_blocks.shape == (
+                num_factors,
+                num_vars,
+                var_type.tangent_dim,
+                var_type.tangent_dim,
+            )
+
+            start_concat_col = end_concat_col
+            del end_concat_col
+
+            gram_diagonal_blocks[var_type_idx] = (
+                gram_diagonal_blocks[var_type_idx]
+                .at[jnp.searchsorted(graph.sorted_ids_from_var_type[var_type], ids)]
+                .add(gram_blocks)
+            )
+
+    inv_block_diagonals = [
+        jnp.linalg.inv(batched_block) for batched_block in gram_diagonal_blocks
+    ]
+
+    def preconditioner(vec: jax.Array) -> jax.Array:
+        """Compute block Jacobi preconditioning."""
+        precond_parts = []
+        offset = 0
+        for inv_batched_block in inv_block_diagonals:
+            num_blocks, block_dim, block_dim_ = inv_batched_block.shape
+            assert block_dim == block_dim_
+            precond_parts.append(
+                jnp.einsum(
+                    "bij,bj->bi",
+                    inv_batched_block,
+                    vec[offset : offset + num_blocks * block_dim].reshape(
+                        (num_blocks, block_dim)
+                    ),
+                ).flatten()
+            )
+            offset += num_blocks * block_dim
+        out = jnp.concatenate(precond_parts, axis=0)
+        assert out.shape == vec.shape
+        return out
+
+    return preconditioner
diff --git a/src/jaxls/_solvers.py b/src/jaxls/_solvers.py
index 05b4834..8716531 100644
--- a/src/jaxls/_solvers.py
+++ b/src/jaxls/_solvers.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import functools
-from typing import TYPE_CHECKING, Callable, Hashable, cast
+from typing import TYPE_CHECKING, Callable, Hashable, Literal, assert_never, cast
 
 import jax
 import jax.experimental.sparse
@@ -12,7 +12,12 @@ import sksparse.cholmod
 from jax import numpy as jnp
 
-from ._sparse_matrices import SparseCooMatrix, SparseCsrMatrix
+from jaxls._preconditioning import (
+    make_block_jacobi_precoditioner,
+    make_point_jacobi_precoditioner,
+)
+
+from ._sparse_matrices import BlockRowSparseMatrix, SparseCsrMatrix
 from ._variables import VarTypeOrdering, VarValues
 from .utils import jax_log
 
 
@@ -75,26 +80,73 @@
 
 
 @jdc.pytree_dataclass
-class ConjugateGradientLinearSolver:
-    """Iterative solver for sparse linear systems. Can run on CPU or GPU."""
+class ConjugateGradientState:
+    """State used for Eisenstat-Walker criterion in ConjugateGradientLinearSolver."""
+
+    ATb_norm_prev: float | jax.Array
+    """Previous norm of ATb."""
+    eta: float | jax.Array
+    """Current tolerance."""
 
-    tolerance: float = 1e-7
-    inexact_step_eta: float | None = 1e-2
-    """Forcing sequence parameter for inexact Newton steps. CG tolerance is set to
-    `eta / iteration #`.
-
-    For reference, see AN INEXACT LEVENBERG-MARQUARDT METHOD FOR LARGE SPARSE NONLINEAR
-    LEAST SQUARES, Wright & Holt 1983."""
+
+@jdc.pytree_dataclass
+class ConjugateGradientLinearSolver:
+    """Iterative solver for sparse linear systems. Can run on CPU or GPU.
+
+    For inexact steps, we use the Eisenstat-Walker criterion. For reference,
+    see "Choosing the Forcing Terms in an Inexact Newton Method", Eisenstat &
+    Walker, 1996.
+    """
+
+    tolerance_min: float = 1e-7
+    tolerance_max: float = 1e-2
+
+    eisenstat_walker_gamma: float = 0.9
+    """Eisenstat-Walker criterion gamma term. Controls how quickly the tolerance
+    decreases. Typical values range from 0.5 to 0.9. Higher values keep the
+    forcing term, and thus the CG tolerance, larger for longer."""
+    eisenstat_walker_alpha: float = 2.0
+    """Eisenstat-Walker criterion alpha term. Determines rate at which the
+    tolerance changes based on residual reduction. Typical values are 1.5 or
+    2.0. Higher values make the tolerance more sensitive to residual changes."""
+
+    preconditioner: jdc.Static[Literal["block-jacobi", "point-jacobi"] | None] = (
+        "block-jacobi"
+    )
+    """Preconditioner to use for linear solves."""
 
     def _solve(
         self,
+        graph: FactorGraph,
+        A_blocksparse: BlockRowSparseMatrix,
         ATA_multiply: Callable[[jax.Array], jax.Array],
-        ATA_diagonals: jax.Array,
         ATb: jax.Array,
-        iterations: int | jax.Array,
-    ) -> jax.Array:
+        prev_linear_state: ConjugateGradientState,
+    ) -> tuple[jax.Array, ConjugateGradientState]:
         assert len(ATb.shape) == 1, "ATb should be 1D!"
 
+        # Preconditioning setup.
+        if self.preconditioner == "block-jacobi":
+            preconditioner = make_block_jacobi_precoditioner(graph, A_blocksparse)
+        elif self.preconditioner == "point-jacobi":
+            preconditioner = make_point_jacobi_precoditioner(A_blocksparse)
+        elif self.preconditioner is None:
+            preconditioner = lambda x: x
+        else:
+            assert_never(self.preconditioner)
+
+        # Calculate tolerance using Eisenstat-Walker criterion.
+        ATb_norm = jnp.linalg.norm(ATb)
+        current_eta = jnp.minimum(
+            self.eisenstat_walker_gamma
+            * (ATb_norm / (prev_linear_state.ATb_norm_prev + 1e-7))
+            ** self.eisenstat_walker_alpha,
+            self.tolerance_max,
+        )
+        current_eta = jnp.maximum(
+            self.tolerance_min, jnp.minimum(current_eta, prev_linear_state.eta)
+        )
+
         # Solve with conjugate gradient.
         initial_x = jnp.zeros(ATb.shape)
         solution_values, _ = jax.scipy.sparse.linalg.cg(
@@ -103,15 +155,12 @@ def _solve(
             x0=initial_x,
             # https://en.wikipedia.org/wiki/Conjugate_gradient_method#Convergence_properties
             maxiter=len(initial_x),
-            tol=cast(
-                float,
-                jnp.maximum(self.tolerance, self.inexact_step_eta / (iterations + 1)),
-            )
-            if self.inexact_step_eta is not None
-            else self.tolerance,
-            M=lambda x: x / ATA_diagonals,  # Jacobi preconditioner.
+            tol=cast(float, current_eta),
+            M=preconditioner,
+        )
+        return solution_values, ConjugateGradientState(
+            ATb_norm_prev=ATb_norm, eta=current_eta
         )
-        return solution_values
 
 
 # Nonlinear solvers.
@@ -126,6 +175,8 @@ class NonlinearSolverState:
     done: bool | jax.Array
     lambd: float | jax.Array
 
+    linear_state: ConjugateGradientState | None
+
 
 @jdc.pytree_dataclass
 class NonlinearSolver:
@@ -149,6 +200,11 @@ def solve(self, graph: FactorGraph, initial_vals: VarValues) -> VarValues:
             lambd=self.trust_region.lambda_initial
            if self.trust_region is not None
            else 0.0,
+            linear_state=None
+            if isinstance(self.linear_solver, CholmodLinearSolver)
+            else ConjugateGradientState(
+                ATb_norm_prev=0.0, eta=self.linear_solver.tolerance_max
+            ),
         )
 
         # Optimization.
@@ -190,18 +246,17 @@ def step(
         # Compute right-hand side of normal equation.
         ATb = -AT_multiply(state.residual_vector)
 
+        linear_state = None
         if isinstance(self.linear_solver, ConjugateGradientLinearSolver):
-            # Get diagonals of ATA for preconditioning.
-            ATA_diagonals = (
-                jnp.zeros_like(ATb).at[graph.jac_coords_coo.cols].add(jac_values**2)
-            )
-            local_delta = self.linear_solver._solve(
+            assert isinstance(state.linear_state, ConjugateGradientState)
+            local_delta, linear_state = self.linear_solver._solve(
+                graph,
+                A_blocksparse,
                 # We could also use (lambd * ATA_diagonals * vec) for
                 # scale-invariant damping. But this is hard to match with CHOLMOD.
                lambda vec: AT_multiply(A_multiply(vec)) + state.lambd * vec,
-                ATA_diagonals,
-                ATb,
-                iterations=state.iterations,
+                ATb=ATb,
+                prev_linear_state=state.linear_state,
             )
         elif isinstance(self.linear_solver, CholmodLinearSolver):
             A_csr = SparseCsrMatrix(jac_values, graph.jac_coords_csr)
@@ -239,6 +294,10 @@ def step(
         proposed_residual_vector = graph.compute_residual_vector(vals)
         proposed_cost = jnp.sum(proposed_residual_vector**2)
 
+        # Update ATb_norm for Eisenstat-Walker criterion.
+        if linear_state is not None:
+            state_next.linear_state = linear_state
+
         # Always accept Gauss-Newton steps.
        if self.trust_region is None:
             state_next.vals = vals
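
For readers skimming the diff: all of the new knobs live on `ConjugateGradientLinearSolver`. Below is a minimal sketch of constructing one. The dataclass fields and defaults are taken from the patch above, but the import path and the `NonlinearSolver(linear_solver=...)` wiring are assumptions about the surrounding API, not something this diff shows.

```python
# Sketch only: field names/defaults come from the patch; the import and the
# NonlinearSolver wiring are assumed, not confirmed jaxls API.
from jaxls._solvers import ConjugateGradientLinearSolver, NonlinearSolver

cg = ConjugateGradientLinearSolver(
    preconditioner="block-jacobi",  # or "point-jacobi", or None to disable.
    tolerance_min=1e-7,  # Floor on the CG tolerance.
    tolerance_max=1e-2,  # Ceiling on the CG tolerance; also the initial eta.
    eisenstat_walker_gamma=0.9,
    eisenstat_walker_alpha=2.0,
)
solver = NonlinearSolver(linear_solver=cg)  # Other arguments omitted here.
```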
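On the preconditioners themselves: point Jacobi approximates A^T A by its diagonal, which is just the vector of squared column norms of A. Here is a dense toy equivalent of the quantity `make_point_jacobi_precoditioner` builds; the real code accumulates the same diagonal from the block-sparse Jacobian without ever materializing A or A^T A, and the toy problem below is ours, not from the patch.

```python
import jax
from jax import numpy as jnp

# Toy dense least squares problem; A stands in for the stacked Jacobian.
A = jax.random.normal(jax.random.PRNGKey(0), (50, 10))
b = jax.random.normal(jax.random.PRNGKey(1), (50,))
ATb = A.T @ b

# diag(A^T A) is the squared L2 norm of each column of A.
ATA_diagonals = jnp.sum(A**2, axis=0)

# Point Jacobi preconditioner: elementwise division by diag(A^T A).
x, _ = jax.scipy.sparse.linalg.cg(
    A=lambda v: A.T @ (A @ v),  # Matrix-free normal-equations product.
    b=ATb,
    M=lambda v: v / ATA_diagonals,
)
```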
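Block Jacobi generalizes this by inverting the small diagonal blocks of A^T A, one `tangent_dim`-sized block per variable, and applying them segment-wise with the same `einsum("bij,bj->bi", ...)` contraction used in the patch. A dense sketch with made-up shapes (5 variables of tangent dimension 2); the patch instead assembles these blocks factor-by-factor from the block-sparse Jacobian:

```python
import jax
from jax import numpy as jnp

num_vars, tangent_dim = 5, 2
A = jax.random.normal(jax.random.PRNGKey(0), (40, num_vars * tangent_dim))
ATA = A.T @ A

# Stack the (tangent_dim, tangent_dim) diagonal blocks of A^T A, then invert
# each one with a batched matrix inverse.
diag_blocks = jnp.stack(
    [
        ATA[
            i * tangent_dim : (i + 1) * tangent_dim,
            i * tangent_dim : (i + 1) * tangent_dim,
        ]
        for i in range(num_vars)
    ]
)
inv_blocks = jnp.linalg.inv(diag_blocks)

def block_jacobi(vec: jax.Array) -> jax.Array:
    # Apply each inverted block to its segment of the input vector.
    return jnp.einsum(
        "bij,bj->bi", inv_blocks, vec.reshape((num_vars, tangent_dim))
    ).flatten()
```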
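Finally, the forcing-term logic in `_solve` is a clamped variant of Eisenstat and Walker's second choice: eta = gamma * (|A^T b| / |A^T b_prev|)^alpha, capped at `tolerance_max`, floored at `tolerance_min`, and kept non-increasing by taking the minimum with the previous eta. The same update, pulled out as a standalone function for readability (the function name is ours, not the patch's):

```python
import jax
from jax import numpy as jnp

def eisenstat_walker_eta(
    ATb_norm: jax.Array,
    ATb_norm_prev: jax.Array,
    eta_prev: jax.Array,
    gamma: float = 0.9,
    alpha: float = 2.0,
    tolerance_min: float = 1e-7,
    tolerance_max: float = 1e-2,
) -> jax.Array:
    # Forcing term: small when the gradient norm dropped sharply, so CG is
    # solved loosely far from convergence and tightly near it.
    eta = gamma * (ATb_norm / (ATb_norm_prev + 1e-7)) ** alpha
    eta = jnp.minimum(eta, tolerance_max)  # Never looser than tolerance_max.
    # Monotone and floored: never looser than the previous step, never
    # tighter than tolerance_min.
    return jnp.maximum(tolerance_min, jnp.minimum(eta, eta_prev))
```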