Skip to content

Commit

Permalink
nits
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Oct 10, 2024
1 parent daabf6d commit c265641
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 13 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

[![pyright](https://github.com/brentyi/jaxls/actions/workflows/pyright.yml/badge.svg)](https://github.com/brentyi/jaxls/actions/workflows/pyright.yml)

_status: working! see limitations [here](#limitations)_
_status: working! see limitations [here](#limitations)_

**`jaxls`** is a library for nonlinear least squares in JAX.

Expand All @@ -21,8 +21,7 @@ Features:
iterative (Jacobi-preconditioned Conjugate Gradient).

Use cases are primarily in least squares problems that are inherently (1) sparse
and (2) inefficient to solve with gradient-based methods. In robotics, these are
ubiquitous across classical approaches to perception, planning, and control.
and (2) inefficient to solve with gradient-based methods.

For the first iteration of this library, written for
[IROS 2021](https://github.com/brentyi/dfgo), see
Expand Down Expand Up @@ -122,6 +121,7 @@ print("Pose 1", solution[pose_vars[1]])
### Limitations

There are many practical features that we don't currently support:

- GPU accelerated Cholesky factorization. (for CHOLMOD we wrap [scikit-sparse](https://scikit-sparse.readthedocs.io/en/latest/), which runs on CPU only)
- Covariance estimation / marginalization.
- Incremental solves.
Expand Down
29 changes: 19 additions & 10 deletions src/jaxls/_factor_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,10 +308,11 @@ def _sort_key(x: Any) -> str:
stacked_factor: Factor = jax.tree.map(
lambda *args: jnp.concatenate(args, axis=0), *group
)
stacked_factor_expanded: _AnalyzedFactor = jax.vmap(_AnalyzedFactor._make)(
stacked_factor
)
stacked_factors.append(stacked_factor_expanded)
stacked_factor_analyzed: _AnalyzedFactor = jax.vmap(
_AnalyzedFactor._analyze
)(stacked_factor)

stacked_factors.append(stacked_factor_analyzed)
factor_counts.append(count_from_group[group_key])

logger.info(
Expand All @@ -321,6 +322,14 @@ def _sort_key(x: Any) -> str:
stacked_factors[-1].compute_residual.__name__,
)

# Check that all variables are present.
for var_type in stacked_factor_analyzed.sorted_ids_from_var_type.keys():
assert var_type in tangent_start_from_var_type, (
f"Found variable of type {var_type} as input"
f" to factor with residual {stacked_factor_analyzed.compute_residual},"
" but variable type is missing from `FactorGraph.make()`."
)

# Compute Jacobian coordinates.
#
# These should be N pairs of (row, col) indices, where rows correspond to
Expand All @@ -332,24 +341,24 @@ def _sort_key(x: Any) -> str:
sorted_ids_from_var_type=sorted_ids_from_var_type,
tangent_start_from_var_type=tangent_start_from_var_type,
)
)(stacked_factor_expanded)
)(stacked_factor_analyzed)
assert (
rows.shape
== cols.shape
== (
count_from_group[group_key],
stacked_factor_expanded.residual_dim,
stacked_factor_analyzed.residual_dim,
rows.shape[-1],
)
)
rows = rows + (
jnp.arange(count_from_group[group_key])[:, None, None]
* stacked_factor_expanded.residual_dim
* stacked_factor_analyzed.residual_dim
)
rows = rows + residual_dim_sum
jac_coords.append((rows.flatten(), cols.flatten()))
residual_dim_sum += (
stacked_factor_expanded.residual_dim * count_from_group[group_key]
stacked_factor_analyzed.residual_dim * count_from_group[group_key]
)

jac_coords_coo: SparseCooCoordinates = SparseCooCoordinates(
Expand Down Expand Up @@ -425,7 +434,7 @@ class _AnalyzedFactor[*Args](Factor[*Args]):

@staticmethod
@jdc.jit
def _make[*Args_](factor: Factor[*Args_]) -> _AnalyzedFactor[*Args_]:
def _analyze[*Args_](factor: Factor[*Args_]) -> _AnalyzedFactor[*Args_]:
"""Construct a factor for our factor graph."""

compute_residual = factor.compute_residual
Expand Down Expand Up @@ -456,7 +465,7 @@ def traverse_args(current: Any, variables: list[Var]) -> list[Var]:
() if isinstance(var.id, int) else var.id.shape
) == batch_axes, "Batch axes of variables do not match."
if len(batch_axes) == 1:
return jax.vmap(_AnalyzedFactor._make)(factor)
return jax.vmap(_AnalyzedFactor._analyze)(factor)

# Cache the residual dimension for this factor.
dummy_vals = jax.eval_shape(VarValues.make, variables)
Expand Down

0 comments on commit c265641

Please sign in to comment.