Skip to content

Commit

Permalink
Improve batch axis support
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Sep 27, 2024
1 parent b7f439b commit f751dc7
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 5 deletions.
9 changes: 6 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,21 @@ faster and easier to use. For additional references, see inspirations like
For Cholesky factorization via CHOLMOD, `scikit-sparse` requires SuiteSparse:

```bash
# Via conda.
# Option 1: via conda.
conda install conda-forge::suitesparse
# Via apt.
# Option 2: via apt.
sudo apt update
sudo apt install -y libsuitesparse-dev
# Via brew.
# Option 3: via brew.
brew install suite-sparse
```

Then, from your environment of choice:

```bash
# Option 1: from git.
pip install git+ssh://[email protected]/brentyi/jaxls.git
# Option 2: editable.
git clone https://github.com/brentyi/jaxls.git
cd jaxls
pip install -e .
Expand Down
32 changes: 30 additions & 2 deletions src/jaxls/_factor_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,25 @@ def make(

factors = tuple(factors)
variables = tuple(variables)

# We're assuming no more than 1 batch axis.
num_factors = 0
for f in factors:
assert len(f._get_batch_axes()) in (0, 1)
num_factors += (
1 if len(f._get_batch_axes()) == 0 else f._get_batch_axes()[0]
)

num_variables = 0
for v in variables:
assert isinstance(v.id, int) or len(v.id.shape) in (0, 1)
num_variables += (
1 if isinstance(v.id, int) or v.id.shape == () else v.id.shape[0]
)
logger.info(
"Building graph with {} factors and {} variables.",
len(factors),
len(variables),
num_factors,
num_variables,
)

# Start by grouping our factors and grabbing a list of (ordered!) variables
Expand Down Expand Up @@ -323,6 +338,19 @@ def traverse_args(current: Any, variables: list[Var]) -> list[Var]:
variables = tuple(traverse_args(args, []))
assert len(variables) > 0

# Support batch axis.
if not isinstance(variables[0].id, int):
batch_axes = variables[0].id.shape
assert len(batch_axes) in (0, 1)
for var in variables[1:]:
assert (
() if isinstance(var.id, int) else var.id.shape
) == batch_axes, "Batch axes of variables do not match."
if len(batch_axes) == 1:
return jax.vmap(Factor._make_impl, in_axes=(None, 0, None))(
compute_residual, args, jac_mode
)

# Cache the residual dimension for this factor.
residual_dim_cache_key = (
compute_residual,
Expand Down

0 comments on commit f751dc7

Please sign in to comment.