diff --git a/README.md b/README.md index 51a8e1f..1a4356a 100644 --- a/README.md +++ b/README.md @@ -2,8 +2,6 @@ [![pyright](https://github.com/brentyi/jaxls/actions/workflows/pyright.yml/badge.svg)](https://github.com/brentyi/jaxls/actions/workflows/pyright.yml) -_status: working! see limitations [here](#limitations)_ - **`jaxls`** is a library for nonlinear least squares in JAX. We provide a factor graph interface for specifying and solving least squares @@ -21,10 +19,10 @@ Currently supported: - Examples provided for SO(2), SO(3), SE(2), and SE(3). - Nonlinear solvers: Levenberg-Marquardt and Gauss-Newton. - Linear subproblem solvers: - - Sparse direct with Cholesky / CHOLMOD, on CPU. - Sparse iterative with Conjugate Gradient. - Preconditioning: block and point Jacobi. - Inexact Newton via Eisenstat-Walker. + - Sparse direct with Cholesky / CHOLMOD, on CPU. - Dense Cholesky for smaller problems. For the first iteration of this library, written for [IROS 2021](https://github.com/brentyi/dfgo), see @@ -37,29 +35,32 @@ and easier to use. For additional references, see inspirations like ### Installation -`jaxls` supports `python>=3.12`. +`jaxls` supports `python>=3.12`: -For Cholesky factorization via CHOLMOD, `scikit-sparse` requires SuiteSparse: +```bash +pip install git+https://github.com/brentyi/jaxls.git +``` + +**Optional: CHOLMOD dependencies** + +By default, we use an iterative linear solver. This requires no extra +dependencies. For some problems, like those with banded matrices, a direct +solver can be much faster. + +For Cholesky factorization via CHOLMOD, we rely on SuiteSparse: ```bash # Option 1: via conda. conda install conda-forge::suitesparse + # Option 2: via apt. -sudo apt update sudo apt install -y libsuitesparse-dev -# Option 3: via brew. -brew install suite-sparse ``` -Then, from your environment of choice: +You'll also need _scikit-sparse_: ```bash -# Option 1: from git. -pip install git+ssh://git@github.com/brentyi/jaxls.git -# Option 2: editable. -git clone https://github.com/brentyi/jaxls.git -cd jaxls -pip install -e . +pip install scikit-sparse ``` ### Pose graph example diff --git a/examples/pose_graph_g2o.py b/examples/pose_graph_g2o.py index 6da94cb..4b2568d 100755 --- a/examples/pose_graph_g2o.py +++ b/examples/pose_graph_g2o.py @@ -20,8 +20,8 @@ def main( g2o_path: pathlib.Path = pathlib.Path(__file__).parent / "data/input_M3500_g2o.g2o", linear_solver: Literal[ - "cholmod", "conjugate_gradient", "dense_cholesky" - ] = "cholmod", + "conjugate_gradient", "cholmod", "dense_cholesky" + ] = "conjugate_gradient", ) -> None: # Parse g2o file. with jaxls.utils.stopwatch("Reading g2o file"): diff --git a/pyproject.toml b/pyproject.toml index fe0d900..ac1ec9d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,6 @@ dependencies = [ "jaxlib", "jaxlie>=1.0.0", "jax_dataclasses>=1.0.0", - "scikit-sparse", "loguru", "termcolor", "tqdm", @@ -30,6 +29,7 @@ dependencies = [ [project.optional-dependencies] dev = [ "pyright>=1.1.308", + "scikit-sparse", "ruff", ] diff --git a/src/jaxls/_factor_graph.py b/src/jaxls/_factor_graph.py index 000f789..0581648 100644 --- a/src/jaxls/_factor_graph.py +++ b/src/jaxls/_factor_graph.py @@ -56,8 +56,8 @@ def solve( self, initial_vals: VarValues | None = None, *, - linear_solver: Literal["cholmod", "conjugate_gradient", "dense_cholesky"] - | ConjugateGradientConfig = "cholmod", + linear_solver: Literal["conjugate_gradient", "cholmod", "dense_cholesky"] + | ConjugateGradientConfig = "conjugate_gradient", trust_region: TrustRegionConfig | None = TrustRegionConfig(), termination: TerminationConfig = TerminationConfig(), verbose: bool = True, diff --git a/src/jaxls/_solvers.py b/src/jaxls/_solvers.py index 664b7ed..1971173 100644 --- a/src/jaxls/_solvers.py +++ b/src/jaxls/_solvers.py @@ -9,7 +9,6 @@ import jax_dataclasses as jdc import scipy import scipy.sparse -import sksparse.cholmod from jax import numpy as jnp from jaxls._preconditioning import ( @@ -22,6 +21,8 @@ from .utils import jax_log if TYPE_CHECKING: + import sksparse.cholmod + from ._factor_graph import FactorGraph @@ -48,6 +49,8 @@ def _cholmod_solve_on_host( lambd: float | jax.Array, ) -> jax.Array: """Solve a linear system using CHOLMOD. Should be called on the host.""" + import sksparse.cholmod + # Matrix is transposed when we convert CSR to CSC. A_T_scipy = scipy.sparse.csc_matrix( (A.values, A.coords.indices, A.coords.indptr), shape=A.coords.shape[::-1] @@ -183,7 +186,7 @@ class NonlinearSolver: """Helper class for solving using Gauss-Newton or Levenberg-Marquardt.""" linear_solver: jdc.Static[ - Literal["cholmod", "dense_cholesky", "conjugate_gradient"] + Literal["conjugate_gradient", "cholmod", "dense_cholesky"] ] trust_region: TrustRegionConfig | None termination: TerminationConfig