Skip to content

Commit

Permalink
Linear solvers refactor + add dense Cholesky + fix trust region termi…
Browse files Browse the repository at this point in the history
…nation (#20)

* Linear solver refactor + support dense Cholesky

* README

* Fixes

* Fix

* Fix trust region
  • Loading branch information
brentyi authored Oct 13, 2024
1 parent 6a5c110 commit 9a53fa4
Show file tree
Hide file tree
Showing 8 changed files with 219 additions and 126 deletions.
30 changes: 15 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,25 @@ problems. We accelerate optimization by analyzing the structure of graphs:
repeated factor and variable types are vectorized, and the sparsity of adjacency
in the graph is translated into sparse matrix operations.

Use cases are primarily in least squares problems that are (1) sparse and (2)
inefficient to solve with gradient-based methods.

Currently supported:

- Automatic sparse Jacobians.
- Optimization on manifolds; SO(2), SO(3), SE(2), and SE(3) implementations
included.
- Optimization on manifolds.
- Examples provided for SO(2), SO(3), SE(2), and SE(3).
- Nonlinear solvers: Levenberg-Marquardt and Gauss-Newton.
- Direct linear solves via sparse Cholesky / CHOLMOD, on CPU.
- Iterative linear solves via Conjugate Gradient.
- Preconditioning: block and point Jacobi.
- Inexact Newton via Eisenstat-Walker.

Use cases are primarily in least squares problems that are inherently (1)
sparse and (2) inefficient to solve with gradient-based methods. These are
common in robotics.

For the first iteration of this library, written for
[IROS 2021](https://github.com/brentyi/dfgo), see
[jaxfg](https://github.com/brentyi/jaxfg). `jaxls` is a rewrite that aims to be
faster and easier to use. For additional references, see inspirations like
- Linear subproblem solvers:
- Sparse direct with Cholesky / CHOLMOD, on CPU.
- Sparse iterative with Conjugate Gradient.
- Preconditioning: block and point Jacobi.
- Inexact Newton via Eisenstat-Walker.
- Dense Cholesky for smaller problems.

For the first iteration of this library, written for [IROS 2021](https://github.com/brentyi/dfgo), see
[jaxfg](https://github.com/brentyi/jaxfg). `jaxls` is a rewrite that is faster
and easier to use. For additional references, see inspirations like
[GTSAM](https://gtsam.org/), [Ceres Solver](http://ceres-solver.org/),
[minisam](https://github.com/dongjing3309/minisam),
[SwiftFusion](https://github.com/borglab/SwiftFusion),
Expand Down
14 changes: 5 additions & 9 deletions examples/pose_graph_g2o.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,15 @@

def main(
g2o_path: pathlib.Path = pathlib.Path(__file__).parent / "data/input_M3500_g2o.g2o",
linear_solver: Literal["cholmod", "pcg"] = "cholmod",
linear_solver: Literal[
"cholmod", "conjugate_gradient", "dense_cholesky"
] = "cholmod",
) -> None:
# Parse g2o file.
with jaxls.utils.stopwatch("Reading g2o file"):
g2o: _g2o_utils.G2OData = _g2o_utils.parse_g2o(g2o_path)
jax.block_until_ready(g2o)

# Get linear solver.
linear_solver_config = {
"cholmod": jaxls.CholmodLinearSolver,
"pcg": jaxls.ConjugateGradientLinearSolver,
}[linear_solver]()

# Making graph.
with jaxls.utils.stopwatch("Making graph"):
graph = jaxls.FactorGraph.make(factors=g2o.factors, variables=g2o.pose_vars)
Expand All @@ -47,12 +43,12 @@ def main(

with jaxls.utils.stopwatch("Running solve"):
solution_vals = graph.solve(
initial_vals, trust_region=None, linear_solver=linear_solver_config
initial_vals, trust_region=None, linear_solver=linear_solver
)

with jaxls.utils.stopwatch("Running solve (again)"):
solution_vals = graph.solve(
initial_vals, trust_region=None, linear_solver=linear_solver_config
initial_vals, trust_region=None, linear_solver=linear_solver
)

# Plot
Expand Down
2 changes: 1 addition & 1 deletion examples/pose_graph_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
graph = jaxls.FactorGraph.make(factors, vars)

# Solve the optimization problem.
solution = graph.solve(linear_solver=jaxls.ConjugateGradientLinearSolver())
solution = graph.solve()
print("All solutions", solution)
print("Pose 0", solution[vars[0]])
print("Pose 1", solution[vars[1]])
3 changes: 1 addition & 2 deletions src/jaxls/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from ._lie_group_variables import SE3Var as SE3Var
from ._lie_group_variables import SO2Var as SO2Var
from ._lie_group_variables import SO3Var as SO3Var
from ._solvers import CholmodLinearSolver as CholmodLinearSolver
from ._solvers import ConjugateGradientLinearSolver as ConjugateGradientLinearSolver
from ._solvers import ConjugateGradientConfig as ConjugateGradientConfig
from ._solvers import TerminationConfig as TerminationConfig
from ._solvers import TrustRegionConfig as TrustRegionConfig
from ._variables import Var as Var
Expand Down
33 changes: 22 additions & 11 deletions src/jaxls/_factor_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@
from typing_extensions import deprecated

from ._solvers import (
CholmodLinearSolver,
ConjugateGradientLinearSolver,
ConjugateGradientConfig,
NonlinearSolver,
TerminationConfig,
TrustRegionConfig,
)
from ._sparse_matrices import (
BlockRowSparseMatrix,
MatrixBlockRow,
SparseBlockRow,
SparseCooCoordinates,
SparseCsrCoordinates,
)
Expand Down Expand Up @@ -56,8 +55,9 @@ class FactorGraph:
def solve(
self,
initial_vals: VarValues | None = None,
linear_solver: CholmodLinearSolver
| ConjugateGradientLinearSolver = CholmodLinearSolver(),
*,
linear_solver: Literal["cholmod", "conjugate_gradient", "dense_cholesky"]
| ConjugateGradientConfig = "cholmod",
trust_region: TrustRegionConfig | None = TrustRegionConfig(),
termination: TerminationConfig = TerminationConfig(),
verbose: bool = True,
Expand All @@ -68,7 +68,19 @@ def solve(
initial_vals = VarValues.make(
var_type(ids) for var_type, ids in self.sorted_ids_from_var_type.items()
)
solver = NonlinearSolver(linear_solver, trust_region, termination, verbose)

# In our internal API, linear_solver needs to always be a string. The
# conjugate gradient config is a separate field. This is more
# convenient to implement, because then the former can be static while
# the latter is a pytree.
conjugate_gradient_config = None
if isinstance(linear_solver, ConjugateGradientConfig):
conjugate_gradient_config = linear_solver
linear_solver = "conjugate_gradient"

solver = NonlinearSolver(
linear_solver, trust_region, termination, conjugate_gradient_config, verbose
)
return solver.solve(graph=self, initial_vals=initial_vals)

def compute_residual_vector(self, vals: VarValues) -> jax.Array:
Expand All @@ -84,7 +96,7 @@ def compute_residual_vector(self, vals: VarValues) -> jax.Array:
return jnp.concatenate(residual_slices, axis=0)

def _compute_jac_values(self, vals: VarValues) -> BlockRowSparseMatrix:
block_rows = list[MatrixBlockRow]()
block_rows = list[SparseBlockRow]()
residual_offset = 0

for factor in self.stacked_factors:
Expand Down Expand Up @@ -151,11 +163,10 @@ def compute_jac_with_perturb(factor: _AnalyzedFactor) -> jax.Array:
assert stacked_jac.shape[-1] == stacked_jac_start_col

block_rows.append(
MatrixBlockRow(
start_row=jnp.arange(num_factor) * factor.residual_dim
+ residual_offset,
SparseBlockRow(
num_cols=self.tangent_dim,
start_cols=tuple(start_cols),
block_widths=tuple(block_widths),
block_num_cols=tuple(block_widths),
blocks_concat=stacked_jac,
)
)
Expand Down
6 changes: 3 additions & 3 deletions src/jaxls/_preconditioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ def make_point_jacobi_precoditioner(
block_l2_cols = jnp.sum(block_row.blocks_concat**2, axis=1).flatten()
indices = jnp.concatenate(
[
(start_col[:, None] + jnp.arange(width)[None, :])
for start_col, width in zip(
block_row.start_cols, block_row.block_widths
(start_col[:, None] + jnp.arange(block_cols)[None, :])
for start_col, block_cols in zip(
block_row.start_cols, block_row.block_num_cols
)
],
axis=1,
Expand Down
Loading

0 comments on commit 9a53fa4

Please sign in to comment.