diff --git a/README.md b/README.md index 4c5eef3..51a8e1f 100644 --- a/README.md +++ b/README.md @@ -11,25 +11,25 @@ problems. We accelerate optimization by analyzing the structure of graphs: repeated factor and variable types are vectorized, and the sparsity of adjacency in the graph is translated into sparse matrix operations. +Use cases are primarily in least squares problems that are (1) sparse and (2) +inefficient to solve with gradient-based methods. + Currently supported: - Automatic sparse Jacobians. -- Optimization on manifolds; SO(2), SO(3), SE(2), and SE(3) implementations - included. +- Optimization on manifolds. + - Examples provided for SO(2), SO(3), SE(2), and SE(3). - Nonlinear solvers: Levenberg-Marquardt and Gauss-Newton. -- Direct linear solves via sparse Cholesky / CHOLMOD, on CPU. -- Iterative linear solves via Conjugate Gradient. - - Preconditioning: block and point Jacobi. - - Inexact Newton via Eisenstat-Walker. - -Use cases are primarily in least squares problems that are inherently (1) -sparse and (2) inefficient to solve with gradient-based methods. These are -common in robotics. - -For the first iteration of this library, written for -[IROS 2021](https://github.com/brentyi/dfgo), see -[jaxfg](https://github.com/brentyi/jaxfg). `jaxls` is a rewrite that aims to be -faster and easier to use. For additional references, see inspirations like +- Linear subproblem solvers: + - Sparse direct with Cholesky / CHOLMOD, on CPU. + - Sparse iterative with Conjugate Gradient. + - Preconditioning: block and point Jacobi. + - Inexact Newton via Eisenstat-Walker. + - Dense Cholesky for smaller problems. + +For the first iteration of this library, written for [IROS 2021](https://github.com/brentyi/dfgo), see +[jaxfg](https://github.com/brentyi/jaxfg). `jaxls` is a rewrite that is faster +and easier to use. For additional references, see inspirations like [GTSAM](https://gtsam.org/), [Ceres Solver](http://ceres-solver.org/), [minisam](https://github.com/dongjing3309/minisam), [SwiftFusion](https://github.com/borglab/SwiftFusion), diff --git a/examples/pose_graph_g2o.py b/examples/pose_graph_g2o.py index 5c7d71f..6da94cb 100755 --- a/examples/pose_graph_g2o.py +++ b/examples/pose_graph_g2o.py @@ -19,19 +19,15 @@ def main( g2o_path: pathlib.Path = pathlib.Path(__file__).parent / "data/input_M3500_g2o.g2o", - linear_solver: Literal["cholmod", "pcg"] = "cholmod", + linear_solver: Literal[ + "cholmod", "conjugate_gradient", "dense_cholesky" + ] = "cholmod", ) -> None: # Parse g2o file. with jaxls.utils.stopwatch("Reading g2o file"): g2o: _g2o_utils.G2OData = _g2o_utils.parse_g2o(g2o_path) jax.block_until_ready(g2o) - # Get linear solver. - linear_solver_config = { - "cholmod": jaxls.CholmodLinearSolver, - "pcg": jaxls.ConjugateGradientLinearSolver, - }[linear_solver]() - # Making graph. 
with jaxls.utils.stopwatch("Making graph"): graph = jaxls.FactorGraph.make(factors=g2o.factors, variables=g2o.pose_vars) @@ -47,12 +43,12 @@ def main( with jaxls.utils.stopwatch("Running solve"): solution_vals = graph.solve( - initial_vals, trust_region=None, linear_solver=linear_solver_config + initial_vals, trust_region=None, linear_solver=linear_solver ) with jaxls.utils.stopwatch("Running solve (again)"): solution_vals = graph.solve( - initial_vals, trust_region=None, linear_solver=linear_solver_config + initial_vals, trust_region=None, linear_solver=linear_solver ) # Plot diff --git a/examples/pose_graph_simple.py b/examples/pose_graph_simple.py index bec585e..b0829c9 100644 --- a/examples/pose_graph_simple.py +++ b/examples/pose_graph_simple.py @@ -48,7 +48,7 @@ graph = jaxls.FactorGraph.make(factors, vars) # Solve the optimization problem. -solution = graph.solve(linear_solver=jaxls.ConjugateGradientLinearSolver()) +solution = graph.solve() print("All solutions", solution) print("Pose 0", solution[vars[0]]) print("Pose 1", solution[vars[1]]) diff --git a/src/jaxls/__init__.py b/src/jaxls/__init__.py index 01e910f..c80c198 100644 --- a/src/jaxls/__init__.py +++ b/src/jaxls/__init__.py @@ -5,8 +5,7 @@ from ._lie_group_variables import SE3Var as SE3Var from ._lie_group_variables import SO2Var as SO2Var from ._lie_group_variables import SO3Var as SO3Var -from ._solvers import CholmodLinearSolver as CholmodLinearSolver -from ._solvers import ConjugateGradientLinearSolver as ConjugateGradientLinearSolver +from ._solvers import ConjugateGradientConfig as ConjugateGradientConfig from ._solvers import TerminationConfig as TerminationConfig from ._solvers import TrustRegionConfig as TrustRegionConfig from ._variables import Var as Var diff --git a/src/jaxls/_factor_graph.py b/src/jaxls/_factor_graph.py index 6d7c6ac..16b2b11 100644 --- a/src/jaxls/_factor_graph.py +++ b/src/jaxls/_factor_graph.py @@ -13,15 +13,14 @@ from typing_extensions import deprecated from ._solvers import ( - CholmodLinearSolver, - ConjugateGradientLinearSolver, + ConjugateGradientConfig, NonlinearSolver, TerminationConfig, TrustRegionConfig, ) from ._sparse_matrices import ( BlockRowSparseMatrix, - MatrixBlockRow, + SparseBlockRow, SparseCooCoordinates, SparseCsrCoordinates, ) @@ -56,8 +55,9 @@ class FactorGraph: def solve( self, initial_vals: VarValues | None = None, - linear_solver: CholmodLinearSolver - | ConjugateGradientLinearSolver = CholmodLinearSolver(), + *, + linear_solver: Literal["cholmod", "conjugate_gradient", "dense_cholesky"] + | ConjugateGradientConfig = "cholmod", trust_region: TrustRegionConfig | None = TrustRegionConfig(), termination: TerminationConfig = TerminationConfig(), verbose: bool = True, @@ -68,7 +68,19 @@ def solve( initial_vals = VarValues.make( var_type(ids) for var_type, ids in self.sorted_ids_from_var_type.items() ) - solver = NonlinearSolver(linear_solver, trust_region, termination, verbose) + + # In our internal API, linear_solver needs to always be a string. The + # conjugate gradient config is a separate field. This is more + # convenient to implement, because then the former can be static while + # the latter is a pytree. 
+ conjugate_gradient_config = None + if isinstance(linear_solver, ConjugateGradientConfig): + conjugate_gradient_config = linear_solver + linear_solver = "conjugate_gradient" + + solver = NonlinearSolver( + linear_solver, trust_region, termination, conjugate_gradient_config, verbose + ) return solver.solve(graph=self, initial_vals=initial_vals) def compute_residual_vector(self, vals: VarValues) -> jax.Array: @@ -84,7 +96,7 @@ def compute_residual_vector(self, vals: VarValues) -> jax.Array: return jnp.concatenate(residual_slices, axis=0) def _compute_jac_values(self, vals: VarValues) -> BlockRowSparseMatrix: - block_rows = list[MatrixBlockRow]() + block_rows = list[SparseBlockRow]() residual_offset = 0 for factor in self.stacked_factors: @@ -151,11 +163,10 @@ def compute_jac_with_perturb(factor: _AnalyzedFactor) -> jax.Array: assert stacked_jac.shape[-1] == stacked_jac_start_col block_rows.append( - MatrixBlockRow( - start_row=jnp.arange(num_factor) * factor.residual_dim - + residual_offset, + SparseBlockRow( + num_cols=self.tangent_dim, start_cols=tuple(start_cols), - block_widths=tuple(block_widths), + block_num_cols=tuple(block_widths), blocks_concat=stacked_jac, ) ) diff --git a/src/jaxls/_preconditioning.py b/src/jaxls/_preconditioning.py index 9f9c9fe..2368740 100644 --- a/src/jaxls/_preconditioning.py +++ b/src/jaxls/_preconditioning.py @@ -25,9 +25,9 @@ def make_point_jacobi_precoditioner( block_l2_cols = jnp.sum(block_row.blocks_concat**2, axis=1).flatten() indices = jnp.concatenate( [ - (start_col[:, None] + jnp.arange(width)[None, :]) - for start_col, width in zip( - block_row.start_cols, block_row.block_widths + (start_col[:, None] + jnp.arange(block_cols)[None, :]) + for start_col, block_cols in zip( + block_row.start_cols, block_row.block_num_cols ) ], axis=1, diff --git a/src/jaxls/_solvers.py b/src/jaxls/_solvers.py index d55cb6d..7367024 100644 --- a/src/jaxls/_solvers.py +++ b/src/jaxls/_solvers.py @@ -28,59 +28,57 @@ _cholmod_analyze_cache: dict[Hashable, sksparse.cholmod.Factor] = {} -@jdc.pytree_dataclass -class CholmodLinearSolver: - """Direct solver for sparse linear systems. Runs on CPU.""" +def _cholmod_solve( + A: SparseCsrMatrix, ATb: jax.Array, lambd: float | jax.Array +) -> jax.Array: + """JIT-friendly linear solve using CHOLMOD.""" + return jax.pure_callback( + _cholmod_solve_on_host, + ATb, # Result shape/dtype. + A, + ATb, + lambd, + vectorized=False, + ) - def _solve( - self, A: SparseCsrMatrix, ATb: jax.Array, lambd: float | jax.Array - ) -> jax.Array: - return jax.pure_callback( - self._solve_on_host, - ATb, # Result shape/dtype. - A, - ATb, - lambd, - vectorized=False, - ) - def _solve_on_host( - self, - A: SparseCsrMatrix, - ATb: jax.Array, - lambd: float | jax.Array, - ) -> jax.Array: - # Matrix is transposed when we convert CSR to CSC. - A_T_scipy = scipy.sparse.csc_matrix( - (A.values, A.coords.indices, A.coords.indptr), shape=A.coords.shape[::-1] - ) +def _cholmod_solve_on_host( + A: SparseCsrMatrix, + ATb: jax.Array, + lambd: float | jax.Array, +) -> jax.Array: + """Solve a linear system using CHOLMOD. Should be called on the host.""" + # Matrix is transposed when we convert CSR to CSC. + A_T_scipy = scipy.sparse.csc_matrix( + (A.values, A.coords.indices, A.coords.indptr), shape=A.coords.shape[::-1] + ) - # Cache sparsity pattern analysis. 
- cache_key = ( - A.coords.indices.tobytes(), - A.coords.indptr.tobytes(), - A.coords.shape, - ) - factor = _cholmod_analyze_cache.get(cache_key, None) - if factor is None: - factor = sksparse.cholmod.analyze_AAt(A_T_scipy) - _cholmod_analyze_cache[cache_key] = factor - - max_cache_size = 512 - if len(_cholmod_analyze_cache) > max_cache_size: - _cholmod_analyze_cache.pop(next(iter(_cholmod_analyze_cache))) - - # Factorize and solve - factor = factor.cholesky_AAt( - A_T_scipy, - # Some simple linear problems blow up without this 1e-5 term. - beta=lambd + 1e-5, - ) - return factor.solve_A(ATb) + # Cache sparsity pattern analysis. + cache_key = ( + A.coords.indices.tobytes(), + A.coords.indptr.tobytes(), + A.coords.shape, + ) + factor = _cholmod_analyze_cache.get(cache_key, None) + if factor is None: + factor = sksparse.cholmod.analyze_AAt(A_T_scipy) + _cholmod_analyze_cache[cache_key] = factor + + max_cache_size = 512 + if len(_cholmod_analyze_cache) > max_cache_size: + _cholmod_analyze_cache.pop(next(iter(_cholmod_analyze_cache))) + + # Factorize and solve + factor = factor.cholesky_AAt( + A_T_scipy, + # Some simple linear problems blow up without this 1e-5 term. + beta=lambd + 1e-5, + ) + return factor.solve_A(ATb) @jdc.pytree_dataclass -class ConjugateGradientState: +class _ConjugateGradientState: """State used for Eisenstat-Walker criterion in ConjugateGradientLinearSolver.""" ATb_norm_prev: float | jax.Array @@ -90,7 +88,7 @@ class ConjugateGradientState: @jdc.pytree_dataclass -class ConjugateGradientLinearSolver: +class ConjugateGradientConfig: """Iterative solver for sparse linear systems. Can run on CPU or GPU. For inexact steps, we use the Eisenstat-Walker criterion. For reference, @@ -110,8 +108,8 @@ class ConjugateGradientLinearSolver: tolerance changes based on residual reduction. Typical values are 1.5 or 2.0. Higher values make the tolerance more sensitive to residual changes.""" - preconditioner: jdc.Static[Literal["block-jacobi", "point-jacobi"] | None] = ( - "block-jacobi" + preconditioner: jdc.Static[Literal["block_jacobi", "point_jacobi"] | None] = ( + "block_jacobi" ) """Preconditioner to use for linear solves.""" @@ -121,14 +119,14 @@ def _solve( A_blocksparse: BlockRowSparseMatrix, ATA_multiply: Callable[[jax.Array], jax.Array], ATb: jax.Array, - prev_linear_state: ConjugateGradientState, - ) -> tuple[jax.Array, ConjugateGradientState]: + prev_linear_state: _ConjugateGradientState, + ) -> tuple[jax.Array, _ConjugateGradientState]: assert len(ATb.shape) == 1, "ATb should be 1D!" # Preconditioning setup. - if self.preconditioner == "block-jacobi": + if self.preconditioner == "block_jacobi": preconditioner = make_block_jacobi_precoditioner(graph, A_blocksparse) - elif self.preconditioner == "point-jacobi": + elif self.preconditioner == "point_jacobi": preconditioner = make_point_jacobi_precoditioner(A_blocksparse) elif self.preconditioner is None: preconditioner = lambda x: x @@ -158,7 +156,7 @@ def _solve( tol=cast(float, current_eta), M=preconditioner, ) - return solution_values, ConjugateGradientState( + return solution_values, _ConjugateGradientState( ATb_norm_prev=ATb_norm, eta=current_eta ) @@ -177,22 +175,26 @@ class NonlinearSolverState: lambd: float | jax.Array # Conjugate gradient state. Not used for other solvers. 
-    cg_state: ConjugateGradientState | None
+    cg_state: _ConjugateGradientState | None
 
 
 @jdc.pytree_dataclass
 class NonlinearSolver:
     """Helper class for solving using Gauss-Newton or Levenberg-Marquardt."""
 
-    linear_solver: CholmodLinearSolver | ConjugateGradientLinearSolver
+    linear_solver: jdc.Static[
+        Literal["cholmod", "dense_cholesky", "conjugate_gradient"]
+    ]
     trust_region: TrustRegionConfig | None
     termination: TerminationConfig
+    conjugate_gradient_config: ConjugateGradientConfig | None
     verbose: jdc.Static[bool]
 
     @jdc.jit
     def solve(self, graph: FactorGraph, initial_vals: VarValues) -> VarValues:
         vals = initial_vals
         residual_vector = graph.compute_residual_vector(vals)
+
         state = NonlinearSolverState(
             iterations=0,
             vals=vals,
@@ -204,9 +206,14 @@ def solve(self, graph: FactorGraph, initial_vals: VarValues) -> VarValues:
             if self.trust_region is not None
             else 0.0,
             cg_state=None
-            if isinstance(self.linear_solver, CholmodLinearSolver)
-            else ConjugateGradientState(
-                ATb_norm_prev=0.0, eta=self.linear_solver.tolerance_max
+            if self.linear_solver != "conjugate_gradient"
+            else _ConjugateGradientState(
+                ATb_norm_prev=0.0,
+                eta=(
+                    ConjugateGradientConfig()
+                    if self.conjugate_gradient_config is None
+                    else self.conjugate_gradient_config
+                ).tolerance_max,
             ),
         )
 
@@ -254,9 +261,18 @@ def step(
         ATb = -AT_multiply(state.residual_vector)
 
         linear_state = None
-        if isinstance(self.linear_solver, ConjugateGradientLinearSolver):
-            assert isinstance(state.cg_state, ConjugateGradientState)
-            local_delta, linear_state = self.linear_solver._solve(
+        if self.linear_solver == "conjugate_gradient":
+            # `linear_solver` is always a string here: FactorGraph.solve()
+            # moves any ConjugateGradientConfig it receives into
+            # `conjugate_gradient_config`. Use that config if it was
+            # provided, otherwise fall back to the defaults.
+            cg_config = (
+                ConjugateGradientConfig()
+                if self.conjugate_gradient_config is None
+                else self.conjugate_gradient_config
+            )
+            assert isinstance(state.cg_state, _ConjugateGradientState)
+            local_delta, linear_state = cg_config._solve(
                 graph,
                 A_blocksparse,
                 # We could also use (lambd * ATA_diagonals * vec) for
@@ -265,11 +281,19 @@ def step(
                 ATb=ATb,
                 prev_linear_state=state.cg_state,
            )
-        elif isinstance(self.linear_solver, CholmodLinearSolver):
+        elif self.linear_solver == "cholmod":
+            # Use CHOLMOD for a sparse direct solve.
             A_csr = SparseCsrMatrix(jac_values, graph.jac_coords_csr)
-            local_delta = self.linear_solver._solve(A_csr, ATb, lambd=state.lambd)
+            local_delta = _cholmod_solve(A_csr, ATb, lambd=state.lambd)
+        elif self.linear_solver == "dense_cholesky":
+            A_dense = A_blocksparse.to_dense()
+            ATA = A_dense.T @ A_dense
+            diag_idx = jnp.arange(ATA.shape[0])
+            ATA = ATA.at[diag_idx, diag_idx].add(state.lambd)
+            cho_factor = jax.scipy.linalg.cho_factor(ATA)
+            local_delta = jax.scipy.linalg.cho_solve(cho_factor, ATb)
         else:
-            assert False
+            assert_never(self.linear_solver)
 
         vals = state.vals._retract(local_delta, graph.tangent_ordering)
         if self.verbose:
@@ -321,6 +345,17 @@ def step(
         if linear_state is not None:
             state_next.cg_state = linear_state
 
+        # Compute termination criteria.
+        state_next.termination_criteria, state_next.termination_deltas = (
+            self.termination._check_convergence(
+                state,
+                cost_updated=proposed_cost,
+                tangent=local_delta,
+                tangent_ordering=graph.tangent_ordering,
+                ATb=ATb,
+            )
+        )
+
         # Always accept Gauss-Newton steps.
         if self.trust_region is None:
             state_next.vals = vals
@@ -338,6 +373,11 @@ def step(
             )
             accept_flag = step_quality >= self.trust_region.step_quality_min
 
+            # Don't terminate if we're rejecting the step.
+ state_next.termination_criteria = jnp.logical_and( + accept_flag, state_next.termination_criteria + ) + state_next.vals = jax.tree_map( lambda proposed, current: jnp.where(accept_flag, proposed, current), vals, @@ -362,15 +402,6 @@ def step( ) state_next.iterations += 1 - state_next.termination_criteria, state_next.termination_deltas = ( - self.termination._check_convergence( - state, - cost_updated=state_next.cost, - tangent=local_delta, - tangent_ordering=graph.tangent_ordering, - ATb=ATb, - ) - ) return state_next diff --git a/src/jaxls/_sparse_matrices.py b/src/jaxls/_sparse_matrices.py index 2e58b9a..5824b9e 100644 --- a/src/jaxls/_sparse_matrices.py +++ b/src/jaxls/_sparse_matrices.py @@ -9,23 +9,68 @@ @jdc.pytree_dataclass -class MatrixBlockRow: - start_row: jax.Array - """Row indices of the start of each block. Shape should be `(num_blocks,)`.""" +class SparseBlockRow: + """A sparse block-row. Each block-row contains: + + - A set of N blocks with shape (rows, cols_i), for i=1...N. + - Every block has an equal number of rows. + - Blocks can have different numbers of columns. We store these in + the `block_num_cols` attribute. + - Concatenated values are stored in `blocks_concat`. + - An initial column index for each block in the block row. + - We store these in `start_cols`. + + A `num_block_rows` leading axis will often be prepended to all contained + arrays. In this case, the `SparseBlockRow` structure represents multiple + sequential block rows. Each resulting block-row has the same block count + and widths; they would otherwise not be stackable. However, their sparsity + patterns may vary due to different values held in `start_cols`. + """ + + num_cols: jdc.Static[int] + """Total width of the block-row, including columns with zero values.""" start_cols: tuple[jax.Array, ...] - """Column indices of the start of each block.""" - block_widths: jdc.Static[tuple[int, ...]] - """Width of each block in the block-row.""" + """Column indices of the start of each block. Shape in tuple should be + `([num_block_rows],)`.""" + block_num_cols: jdc.Static[tuple[int, ...]] + """# of columns for each block in the block-row.""" blocks_concat: jax.Array - """Blocks of matrix, concatenated along the column axis. Shape in tuple should be `(num_blocks, rows, cols)`.""" + """Blocks of matrix, concatenated along the column axis. Shape in tuple + should be `([num_block_rows,] rows, cols)`.""" def treedef(self) -> Hashable: return tuple(block.shape for block in self.blocks_concat) + def to_dense(self) -> jax.Array: + """Convert block-row or batched block-rows to dense representation.""" + if self.blocks_concat.ndim == 3: + # Batched block-rows. + (num_block_rows, num_rows, _) = self.blocks_concat.shape + return jax.vmap(SparseBlockRow.to_dense)(self).reshape( + (num_block_rows * num_rows, self.num_cols) + ) + + assert self.blocks_concat.ndim == 2 + num_rows, num_cols_concat = self.blocks_concat.shape + out = jnp.zeros((num_rows, self.num_cols)) + + start_concat_col = 0 + for start_col, block_width in zip(self.start_cols, self.block_num_cols): + end_concat_col = start_concat_col + block_width + out = jax.lax.dynamic_update_slice( + out, + update=self.blocks_concat[:, start_concat_col:end_concat_col], + start_indices=(0, start_col), + ) + start_concat_col = end_concat_col + + assert start_concat_col == num_cols_concat + return out + @jdc.pytree_dataclass class BlockRowSparseMatrix: - block_rows: tuple[MatrixBlockRow, ...] + block_rows: tuple[SparseBlockRow, ...] """Batched block-rows, ordered. 
Each element in the tuple has a leading axis, which represents consecutive block-rows.""" shape: jdc.Static[tuple[int, int]] @@ -42,9 +87,11 @@ def multiply(self, target: jax.Array) -> jax.Array: del block_rows # Get slices corresponding to nonzero terms in block-row. - assert len(block_row.start_cols) == len(block_row.block_widths) + assert len(block_row.start_cols) == len(block_row.block_num_cols) target_slice_parts = list[jax.Array]() - for start_cols, width in zip(block_row.start_cols, block_row.block_widths): + for start_cols, width in zip( + block_row.start_cols, block_row.block_num_cols + ): assert start_cols.shape == (n_block,) assert isinstance(width, int) slice_part = jax.vmap( @@ -69,6 +116,15 @@ def multiply(self, target: jax.Array) -> jax.Array: result = jnp.concatenate(out_slices, axis=0) return result + def to_dense(self) -> jax.Array: + """Convert to a dense matrix.""" + out = jnp.concatenate( + [block_row.to_dense() for block_row in self.block_rows], + axis=0, + ) + assert out.shape == self.shape + return out + @jdc.pytree_dataclass class SparseCsrCoordinates:
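A short usage sketch of the `linear_solver` argument introduced in `FactorGraph.solve` above; it is not part of the patch itself. It assumes `graph` is a `jaxls.FactorGraph` built as in `examples/pose_graph_simple.py`, and only the `preconditioner` field of `ConjugateGradientConfig` is shown; all other fields keep their defaults.

```python
import jaxls

# Assumes `graph` was built as in examples/pose_graph_simple.py, e.g.:
#   graph = jaxls.FactorGraph.make(factors, vars)

# Default: sparse direct solve via Cholesky / CHOLMOD on the CPU.
solution = graph.solve()

# Dense Cholesky, intended for smaller problems.
solution = graph.solve(linear_solver="dense_cholesky")

# Conjugate gradient with the default configuration...
solution = graph.solve(linear_solver="conjugate_gradient")

# ...or with a custom one. Passing a ConjugateGradientConfig implies the
# conjugate gradient solver; FactorGraph.solve() forwards the config to the
# nonlinear solver as `conjugate_gradient_config`.
solution = graph.solve(
    linear_solver=jaxls.ConjugateGradientConfig(preconditioner="point_jacobi")
)
```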
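The Eisenstat-Walker forcing term referenced in `ConjugateGradientConfig` can be summarized with a small standalone sketch: the CG tolerance `eta` tightens as the outer nonlinear residual shrinks. The exponent and clamp values below are illustrative, not the exact constants or field names used in `_solvers.py`; only `tolerance_max` appears in the patch.

```python
import jax
import jax.numpy as jnp


def eisenstat_walker_eta(
    ATb_norm: jax.Array,
    ATb_norm_prev: jax.Array,
    power: float = 2.0,  # Illustrative; the patch mentions typical values of 1.5 or 2.0.
    tolerance_max: float = 1e-2,  # Plays the role of `tolerance_max`; value is illustrative.
    tolerance_min: float = 1e-7,  # Hypothetical lower clamp for this sketch.
) -> jax.Array:
    """Simplified Eisenstat-Walker-style update of the CG tolerance.

    The ratio of the current to previous gradient norm (||A^T b||) drives how
    tightly the inner linear system needs to be solved.
    """
    # Small epsilon avoids division by zero on the first iteration.
    eta = (ATb_norm / (ATb_norm_prev + 1e-12)) ** power
    return jnp.clip(eta, tolerance_min, tolerance_max)
```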
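And a minimal sketch of the layout that `SparseBlockRow.to_dense` produces: the concatenated blocks are scattered into a dense block-row at the column offsets in `start_cols`. The shapes and values below are made up for illustration; only `jax` is required.

```python
import jax.numpy as jnp
from jax import lax

num_cols = 6                                   # Total block-row width (static).
start_cols = (jnp.array(0), jnp.array(4))      # First column of each block.
block_num_cols = (2, 2)                        # Per-block widths (static).
blocks_concat = jnp.arange(8.0).reshape(2, 4)  # (rows, sum(block_num_cols)).

# Scatter each block into the dense row group, mirroring to_dense() above.
out = jnp.zeros((blocks_concat.shape[0], num_cols))
offset = 0
for start_col, width in zip(start_cols, block_num_cols):
    out = lax.dynamic_update_slice(
        out,
        update=blocks_concat[:, offset : offset + width],
        start_indices=(0, start_col),
    )
    offset += width

print(out)
# [[0. 1. 0. 0. 2. 3.]
#  [4. 5. 0. 0. 6. 7.]]
```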