Skip to content

Commit

Permalink
Default to iterative linear solver (#21)
Browse files Browse the repository at this point in the history
* Default to iterative linear solver

* Change default value to example
  • Loading branch information
brentyi authored Oct 16, 2024
1 parent 0ed83aa commit 1613d1d
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 22 deletions.
31 changes: 16 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

[![pyright](https://github.com/brentyi/jaxls/actions/workflows/pyright.yml/badge.svg)](https://github.com/brentyi/jaxls/actions/workflows/pyright.yml)

_status: working! see limitations [here](#limitations)_

**`jaxls`** is a library for nonlinear least squares in JAX.

We provide a factor graph interface for specifying and solving least squares
Expand All @@ -21,10 +19,10 @@ Currently supported:
- Examples provided for SO(2), SO(3), SE(2), and SE(3).
- Nonlinear solvers: Levenberg-Marquardt and Gauss-Newton.
- Linear subproblem solvers:
- Sparse direct with Cholesky / CHOLMOD, on CPU.
- Sparse iterative with Conjugate Gradient.
- Preconditioning: block and point Jacobi.
- Inexact Newton via Eisenstat-Walker.
- Sparse direct with Cholesky / CHOLMOD, on CPU.
- Dense Cholesky for smaller problems.

For the first iteration of this library, written for [IROS 2021](https://github.com/brentyi/dfgo), see
Expand All @@ -37,29 +35,32 @@ and easier to use. For additional references, see inspirations like

### Installation

`jaxls` supports `python>=3.12`.
`jaxls` supports `python>=3.12`:

For Cholesky factorization via CHOLMOD, `scikit-sparse` requires SuiteSparse:
```bash
pip install git+https://github.com/brentyi/jaxls.git
```

**Optional: CHOLMOD dependencies**

By default, we use an iterative linear solver. This requires no extra
dependencies. For some problems, like those with banded matrices, a direct
solver can be much faster.

For Cholesky factorization via CHOLMOD, we rely on SuiteSparse:

```bash
# Option 1: via conda.
conda install conda-forge::suitesparse

# Option 2: via apt.
sudo apt update
sudo apt install -y libsuitesparse-dev
# Option 3: via brew.
brew install suite-sparse
```

Then, from your environment of choice:
You'll also need _scikit-sparse_:

```bash
# Option 1: from git.
pip install git+ssh://[email protected]/brentyi/jaxls.git
# Option 2: editable.
git clone https://github.com/brentyi/jaxls.git
cd jaxls
pip install -e .
pip install scikit-sparse
```

### Pose graph example
Expand Down
4 changes: 2 additions & 2 deletions examples/pose_graph_g2o.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
def main(
g2o_path: pathlib.Path = pathlib.Path(__file__).parent / "data/input_M3500_g2o.g2o",
linear_solver: Literal[
"cholmod", "conjugate_gradient", "dense_cholesky"
] = "cholmod",
"conjugate_gradient", "cholmod", "dense_cholesky"
] = "conjugate_gradient",
) -> None:
# Parse g2o file.
with jaxls.utils.stopwatch("Reading g2o file"):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ dependencies = [
"jaxlib",
"jaxlie>=1.0.0",
"jax_dataclasses>=1.0.0",
"scikit-sparse",
"loguru",
"termcolor",
"tqdm",
Expand All @@ -30,6 +29,7 @@ dependencies = [
[project.optional-dependencies]
dev = [
"pyright>=1.1.308",
"scikit-sparse",
"ruff",
]

Expand Down
4 changes: 2 additions & 2 deletions src/jaxls/_factor_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def solve(
self,
initial_vals: VarValues | None = None,
*,
linear_solver: Literal["cholmod", "conjugate_gradient", "dense_cholesky"]
| ConjugateGradientConfig = "cholmod",
linear_solver: Literal["conjugate_gradient", "cholmod", "dense_cholesky"]
| ConjugateGradientConfig = "conjugate_gradient",
trust_region: TrustRegionConfig | None = TrustRegionConfig(),
termination: TerminationConfig = TerminationConfig(),
verbose: bool = True,
Expand Down
7 changes: 5 additions & 2 deletions src/jaxls/_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import jax_dataclasses as jdc
import scipy
import scipy.sparse
import sksparse.cholmod
from jax import numpy as jnp

from jaxls._preconditioning import (
Expand All @@ -22,6 +21,8 @@
from .utils import jax_log

if TYPE_CHECKING:
import sksparse.cholmod

from ._factor_graph import FactorGraph


Expand All @@ -48,6 +49,8 @@ def _cholmod_solve_on_host(
lambd: float | jax.Array,
) -> jax.Array:
"""Solve a linear system using CHOLMOD. Should be called on the host."""
import sksparse.cholmod

# Matrix is transposed when we convert CSR to CSC.
A_T_scipy = scipy.sparse.csc_matrix(
(A.values, A.coords.indices, A.coords.indptr), shape=A.coords.shape[::-1]
Expand Down Expand Up @@ -183,7 +186,7 @@ class NonlinearSolver:
"""Helper class for solving using Gauss-Newton or Levenberg-Marquardt."""

linear_solver: jdc.Static[
Literal["cholmod", "dense_cholesky", "conjugate_gradient"]
Literal["conjugate_gradient", "cholmod", "dense_cholesky"]
]
trust_region: TrustRegionConfig | None
termination: TerminationConfig
Expand Down

0 comments on commit 1613d1d

Please sign in to comment.