Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Oct 11, 2024
1 parent 7bef329 commit 2a7f0ba
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
18 changes: 8 additions & 10 deletions src/jaxls/_factor_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def compute_jac_with_perturb(factor: _AnalyzedFactor) -> jax.Array:
)
)(jnp.zeros((val_subset._get_tangent_dim(),)))

# Compute Jacobian for each factor.
stacked_jac = jax.vmap(compute_jac_with_perturb)(factor)
(num_factor,) = factor._get_batch_axes()
assert stacked_jac.shape == (
Expand All @@ -125,8 +126,8 @@ def compute_jac_with_perturb(factor: _AnalyzedFactor) -> jax.Array:
)
jac_vals.append(stacked_jac.flatten())

# Compute block-row representation for sparse Jacobian.
stacked_jac_start_col = 0

start_cols = list[jax.Array]()
blocks = list[jax.Array]()
for var_type, ids in self.tangent_ordering.ordered_dict_items(
Expand All @@ -135,12 +136,9 @@ def compute_jac_with_perturb(factor: _AnalyzedFactor) -> jax.Array:
):
(num_factor_, num_vars) = ids.shape
assert num_factor == num_factor_
stacked_jac_end_col = (
stacked_jac_start_col + num_vars * var_type.tangent_dim
)

# Get one block for each variable.
for var_idx in range(ids.shape[-1]):
print(f"{ids.shape=}")
start_cols.append(
jnp.searchsorted(
self.sorted_ids_from_var_type[var_type], ids[..., var_idx]
Expand All @@ -149,16 +147,16 @@ def compute_jac_with_perturb(factor: _AnalyzedFactor) -> jax.Array:
+ self.tangent_start_from_var_type[var_type]
)
assert start_cols[-1].shape == (num_factor_,)
subblock_start = (
stacked_jac_start_col + var_idx * var_type.tangent_dim
)
block_start = stacked_jac_start_col + var_idx * var_type.tangent_dim
blocks.append(
stacked_jac[
..., subblock_start : subblock_start + var_type.tangent_dim
..., block_start : block_start + var_type.tangent_dim
]
)

stacked_jac_start_col = stacked_jac_end_col
stacked_jac_start_col = (
stacked_jac_start_col + num_vars * var_type.tangent_dim
)

assert stacked_jac.shape[-1] == stacked_jac_start_col

Expand Down
1 change: 1 addition & 0 deletions src/jaxls/_sparse_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class BlockRowSparseMatrix:
"""Shape of matrix."""

def multiply(self, target: jax.Array) -> jax.Array:
"""Sparse-dense multiplication."""
assert target.ndim == 1

def multiply_one_block_row(
Expand Down

0 comments on commit 2a7f0ba

Please sign in to comment.