jaxls
is a library for solving sparse nonlinear least squares (NLLS) and iteratively reweighted least squares (IRLS) problems in JAX.
These problems are common in classical robotics and computer vision, and also arise in MAP estimation with Gaussian likelihoods.
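As a brief refresher (standard material, not specific to jaxls): with Gaussian noise models, maximizing the posterior is equivalent to minimizing a sum of squared whitened residuals, which is exactly the NLLS form above:

```math
\arg\max_x \; \prod_i \exp\!\left(-\tfrac{1}{2}\,\lVert r_i(x) \rVert_{\Sigma_i^{-1}}^2\right)
\;=\; \arg\min_x \; \tfrac{1}{2} \sum_i \lVert r_i(x) \rVert_{\Sigma_i^{-1}}^2
```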
To install on Python 3.12+:

```bash
pip install git+https://github.com/brentyi/jaxls.git
```
We provide a factor graph interface for specifying and solving least squares
problems. jaxls
takes advantage of structure in graphs: repeated factor
and variable types are vectorized, and sparsity of adjacency is translated into
sparse matrix operations.
Supported:

- Automatic sparse Jacobians.
- Optimization on manifolds.
  - Examples provided for SO(2), SO(3), SE(2), and SE(3).
- Nonlinear solvers: Levenberg-Marquardt and Gauss-Newton (the linearized subproblem is sketched after this list).
- Linear subproblem solvers:
  - Sparse iterative with Conjugate Gradient.
    - Preconditioning: block and point Jacobi.
    - Inexact Newton via Eisenstat-Walker.
    - Recommended for most problems.
  - Dense Cholesky.
    - Fast for small problems.
  - Sparse Cholesky on CPU, via CHOLMOD.
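Both nonlinear solvers work by repeatedly linearizing the residuals and handing a sparse linear subproblem to one of the solvers above. For reference, this is the standard damped normal-equations step (Gauss-Newton is the λ = 0 case; the exact damping term varies by implementation, with λI also common):

```math
\left(J^\top J + \lambda \,\mathrm{diag}(J^\top J)\right) \Delta x = -J^\top r
```

where r is the stacked residual vector and J is its sparse Jacobian with respect to the variables.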
jaxls borrows heavily from libraries like GTSAM, Ceres Solver, minisam, SwiftFusion, and g2o.
Example usage:

```python
import jaxls
import jaxlie
```
**Defining variables.** Each variable is given an integer ID. They don't need to be contiguous.
```python
pose_vars = [jaxls.SE2Var(0), jaxls.SE2Var(1)]
```
**Defining factors.** Factors are defined using a callable cost function and a set of arguments.
```python
# Factors take two arguments:
# - A callable with signature `(jaxls.VarValues, *Args) -> jax.Array`.
# - A tuple of arguments: the type should be `tuple[*Args]`.
#
# All arguments should be PyTree structures. Variable types within the PyTree
# will be automatically detected.
factors = [
    # Cost on pose 0.
    jaxls.Factor(
        lambda vals, var, init: (vals[var] @ init.inverse()).log(),
        (pose_vars[0], jaxlie.SE2.from_xy_theta(0.0, 0.0, 0.0)),
    ),
    # Cost on pose 1.
    jaxls.Factor(
        lambda vals, var, init: (vals[var] @ init.inverse()).log(),
        (pose_vars[1], jaxlie.SE2.from_xy_theta(2.0, 0.0, 0.0)),
    ),
    # Cost between poses.
    jaxls.Factor(
        lambda vals, var0, var1, delta: (
            (vals[var0].inverse() @ vals[var1]) @ delta.inverse()
        ).log(),
        (pose_vars[0], pose_vars[1], jaxlie.SE2.from_xy_theta(1.0, 0.0, 0.0)),
    ),
]
```
Factors with similar structure, like the first two in this example, are vectorized under the hood.
Batched inputs can also be constructed manually; these are detected by inspecting the shapes of the variable ID arrays in the input. Just add a leading batch axis to every element in the arguments tuple, as in the sketch below.
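Here's a minimal sketch of that manual batching, assuming (per the description above) that variable IDs may be passed as integer arrays and that jaxlie constructors broadcast over a leading batch axis; verify both against the current API:

```python
import jax.numpy as jnp

# Assumption: SE2Var accepts an integer array of IDs, producing a batch of
# variables with a leading axis of size 2.
batched_vars = jaxls.SE2Var(jnp.arange(2))

# Assumption: jaxlie constructors broadcast, giving a batched SE2 target.
batched_targets = jaxlie.SE2.from_xy_theta(
    jnp.array([0.0, 2.0]),  # x targets for poses 0 and 1.
    jnp.zeros(2),  # y targets.
    jnp.zeros(2),  # theta targets.
)

# One Factor object representing two vectorized prior costs.
batched_prior = jaxls.Factor(
    lambda vals, var, target: (vals[var] @ target.inverse()).log(),
    (batched_vars, batched_targets),  # Leading batch axis on every element.
)
```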
**Solving optimization problems.** To set up the optimization problem, solve it, and print the solutions:
```python
graph = jaxls.FactorGraph.make(factors, pose_vars)
solution = graph.solve()
print("All solutions", solution)
print("Pose 0", solution[pose_vars[0]])
print("Pose 1", solution[pose_vars[1]])
```
By default, we use an iterative linear solver. This requires no extra dependencies. For some problems, like those with banded matrices, a direct solver can be much faster.
For Cholesky factorization via CHOLMOD, we rely on SuiteSparse:

```bash
# Option 1: via conda.
conda install conda-forge::suitesparse

# Option 2: via apt.
sudo apt install -y libsuitesparse-dev
```
You'll also need scikit-sparse:

```bash
pip install scikit-sparse
```
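With those installed, switching to the CHOLMOD backend presumably looks something like the call below; the `linear_solver` keyword and the `"cholmod"` literal are assumptions about the API, so check the signature of `graph.solve` before relying on them:

```python
# Hypothetical: keyword name and value are assumptions, not confirmed API.
solution = graph.solve(linear_solver="cholmod")
```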