diff --git a/README.md b/README.md index 7264c0d..b438b12 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![pyright](https://github.com/brentyi/jaxls/actions/workflows/pyright.yml/badge.svg)](https://github.com/brentyi/jaxls/actions/workflows/pyright.yml) -_status: working! see limitations [here](#limitations)_ +_status: working! see limitations [here](#limitations)_ **`jaxls`** is a library for nonlinear least squares in JAX. @@ -21,8 +21,7 @@ Features: iterative (Jacobi-preconditioned Conjugate Gradient). Use cases are primarily in least squares problems that are inherently (1) sparse -and (2) inefficient to solve with gradient-based methods. In robotics, these are -ubiquitous across classical approaches to perception, planning, and control. +and (2) inefficient to solve with gradient-based methods. For the first iteration of this library, written for [IROS 2021](https://github.com/brentyi/dfgo), see @@ -122,6 +121,7 @@ print("Pose 1", solution[pose_vars[1]]) ### Limitations There are many practical features that we don't currently support: + - GPU accelerated Cholesky factorization. (for CHOLMOD we wrap [scikit-sparse](https://scikit-sparse.readthedocs.io/en/latest/), which runs on CPU only) - Covariance estimation / marginalization. - Incremental solves. diff --git a/src/jaxls/_factor_graph.py b/src/jaxls/_factor_graph.py index 77bd624..287f467 100644 --- a/src/jaxls/_factor_graph.py +++ b/src/jaxls/_factor_graph.py @@ -308,10 +308,11 @@ def _sort_key(x: Any) -> str: stacked_factor: Factor = jax.tree.map( lambda *args: jnp.concatenate(args, axis=0), *group ) - stacked_factor_expanded: _AnalyzedFactor = jax.vmap(_AnalyzedFactor._make)( - stacked_factor - ) - stacked_factors.append(stacked_factor_expanded) + stacked_factor_analyzed: _AnalyzedFactor = jax.vmap( + _AnalyzedFactor._analyze + )(stacked_factor) + + stacked_factors.append(stacked_factor_analyzed) factor_counts.append(count_from_group[group_key]) logger.info( @@ -321,6 +322,14 @@ def _sort_key(x: Any) -> str: stacked_factors[-1].compute_residual.__name__, ) + # Check that all variables are present. + for var_type in stacked_factor_analyzed.sorted_ids_from_var_type.keys(): + assert var_type in tangent_start_from_var_type, ( + f"Found variable of type {var_type} as input" + f" to factor with residual {stacked_factor_analyzed.compute_residual}," + " but variable type is missing from `FactorGraph.make()`." + ) + # Compute Jacobian coordinates. # # These should be N pairs of (row, col) indices, where rows correspond to @@ -332,24 +341,24 @@ def _sort_key(x: Any) -> str: sorted_ids_from_var_type=sorted_ids_from_var_type, tangent_start_from_var_type=tangent_start_from_var_type, ) - )(stacked_factor_expanded) + )(stacked_factor_analyzed) assert ( rows.shape == cols.shape == ( count_from_group[group_key], - stacked_factor_expanded.residual_dim, + stacked_factor_analyzed.residual_dim, rows.shape[-1], ) ) rows = rows + ( jnp.arange(count_from_group[group_key])[:, None, None] - * stacked_factor_expanded.residual_dim + * stacked_factor_analyzed.residual_dim ) rows = rows + residual_dim_sum jac_coords.append((rows.flatten(), cols.flatten())) residual_dim_sum += ( - stacked_factor_expanded.residual_dim * count_from_group[group_key] + stacked_factor_analyzed.residual_dim * count_from_group[group_key] ) jac_coords_coo: SparseCooCoordinates = SparseCooCoordinates( @@ -425,7 +434,7 @@ class _AnalyzedFactor[*Args](Factor[*Args]): @staticmethod @jdc.jit - def _make[*Args_](factor: Factor[*Args_]) -> _AnalyzedFactor[*Args_]: + def _analyze[*Args_](factor: Factor[*Args_]) -> _AnalyzedFactor[*Args_]: """Construct a factor for our factor graph.""" compute_residual = factor.compute_residual @@ -456,7 +465,7 @@ def traverse_args(current: Any, variables: list[Var]) -> list[Var]: () if isinstance(var.id, int) else var.id.shape ) == batch_axes, "Batch axes of variables do not match." if len(batch_axes) == 1: - return jax.vmap(_AnalyzedFactor._make)(factor) + return jax.vmap(_AnalyzedFactor._analyze)(factor) # Cache the residual dimension for this factor. dummy_vals = jax.eval_shape(VarValues.make, variables)