Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Default to iterative linear solver #21

Merged
merged 2 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 16 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

[![pyright](https://github.com/brentyi/jaxls/actions/workflows/pyright.yml/badge.svg)](https://github.com/brentyi/jaxls/actions/workflows/pyright.yml)

_status: working! see limitations [here](#limitations)_

**`jaxls`** is a library for nonlinear least squares in JAX.

We provide a factor graph interface for specifying and solving least squares
Expand All @@ -21,10 +19,10 @@ Currently supported:
- Examples provided for SO(2), SO(3), SE(2), and SE(3).
- Nonlinear solvers: Levenberg-Marquardt and Gauss-Newton.
- Linear subproblem solvers:
- Sparse direct with Cholesky / CHOLMOD, on CPU.
- Sparse iterative with Conjugate Gradient.
- Preconditioning: block and point Jacobi.
- Inexact Newton via Eisenstat-Walker.
- Sparse direct with Cholesky / CHOLMOD, on CPU.
- Dense Cholesky for smaller problems.

For the first iteration of this library, written for [IROS 2021](https://github.com/brentyi/dfgo), see
Expand All @@ -37,29 +35,32 @@ and easier to use. For additional references, see inspirations like

### Installation

`jaxls` supports `python>=3.12`.
`jaxls` supports `python>=3.12`:

For Cholesky factorization via CHOLMOD, `scikit-sparse` requires SuiteSparse:
```bash
pip install git+https://github.com/brentyi/jaxls.git
```

**Optional: CHOLMOD dependencies**

By default, we use an iterative linear solver. This requires no extra
dependencies. For some problems, like those with banded matrices, a direct
solver can be much faster.

For Cholesky factorization via CHOLMOD, we rely on SuiteSparse:

```bash
# Option 1: via conda.
conda install conda-forge::suitesparse

# Option 2: via apt.
sudo apt update
sudo apt install -y libsuitesparse-dev
# Option 3: via brew.
brew install suite-sparse
```

Then, from your environment of choice:
You'll also need _scikit-sparse_:

```bash
# Option 1: from git.
pip install git+ssh://[email protected]/brentyi/jaxls.git
# Option 2: editable.
git clone https://github.com/brentyi/jaxls.git
cd jaxls
pip install -e .
pip install scikit-sparse
```

### Pose graph example
Expand Down
4 changes: 2 additions & 2 deletions examples/pose_graph_g2o.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
def main(
g2o_path: pathlib.Path = pathlib.Path(__file__).parent / "data/input_M3500_g2o.g2o",
linear_solver: Literal[
"cholmod", "conjugate_gradient", "dense_cholesky"
] = "cholmod",
"conjugate_gradient", "cholmod", "dense_cholesky"
] = "conjugate_gradient",
) -> None:
# Parse g2o file.
with jaxls.utils.stopwatch("Reading g2o file"):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ dependencies = [
"jaxlib",
"jaxlie>=1.0.0",
"jax_dataclasses>=1.0.0",
"scikit-sparse",
"loguru",
"termcolor",
"tqdm",
Expand All @@ -30,6 +29,7 @@ dependencies = [
[project.optional-dependencies]
dev = [
"pyright>=1.1.308",
"scikit-sparse",
"ruff",
]

Expand Down
4 changes: 2 additions & 2 deletions src/jaxls/_factor_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def solve(
self,
initial_vals: VarValues | None = None,
*,
linear_solver: Literal["cholmod", "conjugate_gradient", "dense_cholesky"]
| ConjugateGradientConfig = "cholmod",
linear_solver: Literal["conjugate_gradient", "cholmod", "dense_cholesky"]
| ConjugateGradientConfig = "conjugate_gradient",
trust_region: TrustRegionConfig | None = TrustRegionConfig(),
termination: TerminationConfig = TerminationConfig(),
verbose: bool = True,
Expand Down
7 changes: 5 additions & 2 deletions src/jaxls/_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import jax_dataclasses as jdc
import scipy
import scipy.sparse
import sksparse.cholmod
from jax import numpy as jnp

from jaxls._preconditioning import (
Expand All @@ -22,6 +21,8 @@
from .utils import jax_log

if TYPE_CHECKING:
import sksparse.cholmod

from ._factor_graph import FactorGraph


Expand All @@ -48,6 +49,8 @@ def _cholmod_solve_on_host(
lambd: float | jax.Array,
) -> jax.Array:
"""Solve a linear system using CHOLMOD. Should be called on the host."""
import sksparse.cholmod

# Matrix is transposed when we convert CSR to CSC.
A_T_scipy = scipy.sparse.csc_matrix(
(A.values, A.coords.indices, A.coords.indptr), shape=A.coords.shape[::-1]
Expand Down Expand Up @@ -183,7 +186,7 @@ class NonlinearSolver:
"""Helper class for solving using Gauss-Newton or Levenberg-Marquardt."""

linear_solver: jdc.Static[
Literal["cholmod", "dense_cholesky", "conjugate_gradient"]
Literal["conjugate_gradient", "cholmod", "dense_cholesky"]
]
trust_region: TrustRegionConfig | None
termination: TerminationConfig
Expand Down
Loading