Scatter-free block-sparse matrices
brentyi committed Oct 11, 2024
1 parent a8a508e commit a87c219
Showing 3 changed files with 94 additions and 98 deletions.
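
This commit swaps the shape-keyed `BlockSparseMatrix` for a `BlockRowSparseMatrix` whose matrix-vector product is assembled by concatenating per-block-row results instead of scatter-adding into the output. Because each stack of factors owns a contiguous, disjoint span of residual rows, per-row products can be flattened and concatenated in order, avoiding `.at[...].add()` scatters. A minimal sketch of the two patterns, with invented shapes (not code from this commit):

    import jax.numpy as jnp

    def matvec_scatter(start_rows, blocks, x_slices, out_dim):
        # blocks: (num_blocks, rows, cols); x_slices: (num_blocks, cols).
        prods = jnp.einsum("bij,bj->bi", blocks, x_slices)
        rows = start_rows[:, None] + jnp.arange(blocks.shape[1])[None, :]
        # Scatter-add each block's product into its rows of the output.
        return jnp.zeros(out_dim).at[rows].add(prods)

    def matvec_concat(blocks_per_row, x_slices_per_row):
        # Scatter-free: block-rows cover contiguous, disjoint output rows,
        # so per-row products can simply be flattened and concatenated.
        out = [
            jnp.einsum("bij,bj->bi", b, x).flatten()
            for b, x in zip(blocks_per_row, x_slices_per_row)
        ]
        return jnp.concatenate(out, axis=0)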
src/jaxls/_factor_graph.py (42 additions, 43 deletions)
@@ -20,8 +20,8 @@
     TrustRegionConfig,
 )
 from ._sparse_matrices import (
-    BlockSparseMatrix,
-    MatrixBlock,
+    BlockRowSparseMatrix,
+    MatrixBlockRow,
     SparseCooCoordinates,
     SparseCsrCoordinates,
 )
@@ -85,9 +85,9 @@ def compute_residual_vector(self, vals: VarValues) -> jax.Array:

     def _compute_jac_values(
         self, vals: VarValues
-    ) -> tuple[jax.Array, BlockSparseMatrix]:
+    ) -> tuple[jax.Array, BlockRowSparseMatrix]:
         jac_vals = []
-        blocks = dict[tuple[int, int], list[MatrixBlock]]()
+        block_rows = list[MatrixBlockRow]()
         residual_offset = 0

         for factor in self.stacked_factors:
@@ -125,58 +125,57 @@ def compute_jac_with_perturb(factor: _AnalyzedFactor) -> jax.Array:
             )
             jac_vals.append(stacked_jac.flatten())

-            start_col = 0
+            stacked_jac_start_col = 0
+
+            start_cols = list[jax.Array]()
+            blocks = list[jax.Array]()
             for var_type, ids in self.tangent_ordering.ordered_dict_items(
                 # This ordering shouldn't actually matter!
                 factor.sorted_ids_from_var_type
             ):
-                block_shape = (factor.residual_dim, var_type.tangent_dim)
                 (num_factor_, num_vars) = ids.shape
                 assert num_factor == num_factor_
-                end_col = start_col + num_vars * var_type.tangent_dim
-
-                block_vals = jnp.moveaxis(
-                    stacked_jac[:, :, start_col:end_col].reshape(
-                        (
-                            num_factor_,
-                            factor.residual_dim,
-                            num_vars,
-                            var_type.tangent_dim,
-                        )
-                    ),
-                    2,
-                    1,
-                ).reshape(
-                    (num_factor_ * num_vars, factor.residual_dim, var_type.tangent_dim)
+                stacked_jac_end_col = (
+                    stacked_jac_start_col + num_vars * var_type.tangent_dim
                 )
-                blocks.setdefault(block_shape, []).append(
-                    MatrixBlock(
-                        start_row=residual_offset
-                        + jnp.repeat(
-                            jnp.arange(num_factor_) * factor.residual_dim, num_vars
-                        ),
-                        start_col=(
-                            jnp.searchsorted(
-                                self.sorted_ids_from_var_type[var_type], ids.flatten()
-                            )
-                            * var_type.tangent_dim
-                            + self.tangent_start_from_var_type[var_type]
-                        ),
-                        values=block_vals,
+
+                for var_idx in range(ids.shape[-1]):
+                    start_cols.append(
+                        jnp.searchsorted(
+                            self.sorted_ids_from_var_type[var_type], ids[..., var_idx]
+                        )
+                        * var_type.tangent_dim
+                        + self.tangent_start_from_var_type[var_type]
                     )
-                )
-                start_col = end_col
+                    assert start_cols[-1].shape == (num_factor_,)
+                    subblock_start = (
+                        stacked_jac_start_col + var_idx * var_type.tangent_dim
+                    )
+                    blocks.append(
+                        stacked_jac[
+                            ..., subblock_start : subblock_start + var_type.tangent_dim
+                        ]
+                    )
+
+                stacked_jac_start_col = stacked_jac_end_col
+
+            assert stacked_jac.shape[-1] == stacked_jac_start_col

-            assert stacked_jac.shape[-1] == start_col
+            block_rows.append(
+                MatrixBlockRow(
+                    start_row=jnp.arange(num_factor) * factor.residual_dim
+                    + residual_offset,
+                    start_cols=tuple(start_cols),
+                    blocks=tuple(blocks),
+                )
+            )

             residual_offset += factor.residual_dim * num_factor
         assert residual_offset == self.residual_dim

-        bsparse_jacobian = BlockSparseMatrix(
-            blocks={
-                shape: jax.tree.map(lambda *x: jnp.concatenate(x, axis=0), *blocklist)
-                for shape, blocklist in blocks.items()
-            },
+        bsparse_jacobian = BlockRowSparseMatrix(
+            block_rows=tuple(block_rows),
             shape=(self.residual_dim, self.tangent_dim),
         )
         jac_vals = jnp.concatenate(jac_vals, axis=0)
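
In the rewritten loop above, `start_cols` maps each variable id to the column where its tangent block starts: `jnp.searchsorted` finds the id's rank among the sorted ids of its type, which is scaled by the tangent dimension and offset by where that type's columns begin. A toy illustration of the indexing, with invented values:

    import jax.numpy as jnp

    sorted_ids = jnp.array([2, 5, 7, 11])  # sorted ids for one variable type
    tangent_dim = 3                        # tangent dimension of that type
    tangent_start = 12                     # first column owned by this type

    ids = jnp.array([7, 2])  # ids referenced by two stacked factors
    start_cols = jnp.searchsorted(sorted_ids, ids) * tangent_dim + tangent_start
    # ranks are [2, 0], so start_cols == [18, 12]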
src/jaxls/_solvers.py (12 additions, 9 deletions)
@@ -78,7 +78,7 @@ def _solve_on_host(
 class ConjugateGradientLinearSolver:
     """Iterative solver for sparse linear systems. Can run on CPU or GPU."""

-    tolerance: float = 1e-5
+    tolerance: float = 1e-6
     inexact_step_eta: float | None = None
     """Forcing sequence parameter for inexact Newton steps. CG tolerance is set to
     `eta / iteration #`.
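
The forcing sequence means CG is solved loosely in early outer iterations and more tightly as the outer solve converges. A hypothetical helper showing the schedule the docstring describes (the `floor` clamp is an assumption, not part of the solver):

    def cg_tolerance(eta: float, iteration: int, floor: float = 1e-6) -> float:
        # Hypothetical: eta / iteration #, clamped so it never reaches zero.
        return max(floor, eta / max(iteration, 1))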
@@ -172,18 +172,18 @@ def step(
         self, graph: FactorGraph, state: NonlinearSolverState
     ) -> NonlinearSolverState:
         jac_values, A_blocksparse = graph._compute_jac_values(state.vals)
-        A_coo = SparseCooMatrix(jac_values, graph.jac_coords_coo).as_jax_bcoo()
-        A_multiply = A_blocksparse.multiply
-        AT_multiply = A_blocksparse.transpose().multiply

-        # Equivalently:
-        # AT_multiply = lambda vec: jax.linear_transpose(
-        #     A_blocksparse.multiply, jnp.zeros((A_blocksparse.shape[1],))
-        # )(vec)[0]
-        # linear_transpose() will return a tuple, with one element per primal.
+        A_multiply = A_blocksparse.multiply
+        AT_multiply_ = jax.linear_transpose(
+            A_multiply, jnp.zeros((A_blocksparse.shape[1],))
+        )
+        # linear_transpose() returns a tuple, with one element per primal.
+        AT_multiply = lambda vec: AT_multiply_(vec)[0]

         ATb = -AT_multiply(state.residual_vector)

         if isinstance(self.linear_solver, ConjugateGradientLinearSolver):
+            A_coo = SparseCooMatrix(jac_values, graph.jac_coords_coo).as_jax_bcoo()
             tangent = self.linear_solver._solve(
                 A_coo,
                 # We could also use (lambd * ATA_diagonals * vec) for
@@ -237,7 +237,10 @@ def step(
             # For Levenberg-Marquardt, we need to evaluate the step quality.
             else:
                 step_quality = (proposed_cost - state.cost) / (
-                    jnp.sum((A_coo @ tangent + state.residual_vector) ** 2) - state.cost
+                    jnp.sum(
+                        (A_blocksparse.multiply(tangent) + state.residual_vector) ** 2
+                    )
+                    - state.cost
                 )
                 accept_flag = step_quality >= self.trust_region.step_quality_min
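
The rewritten `step` relies on `jax.linear_transpose` to derive the Aᵀ matvec from the forward matvec, which is what lets `BlockRowSparseMatrix` drop its explicit `transpose()` method (removed below). A standalone check of the pattern on a toy dense matrix (not code from the diff):

    import jax
    import jax.numpy as jnp

    A = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    A_multiply = lambda x: A @ x  # linear in x: (3,) -> (2,)

    # linear_transpose needs an example primal to fix shapes and dtypes.
    AT_multiply_ = jax.linear_transpose(A_multiply, jnp.zeros((3,)))
    # It returns a tuple with one cotangent per primal argument.
    AT_multiply = lambda v: AT_multiply_(v)[0]

    v = jnp.array([1.0, 1.0])
    assert jnp.allclose(AT_multiply(v), A.T @ v)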
src/jaxls/_sparse_matrices.py (40 additions, 46 deletions)
@@ -1,70 +1,64 @@
 from __future__ import annotations

+from typing import Hashable
+
 import jax
 import jax.experimental.sparse
 import jax_dataclasses as jdc
 from jax import numpy as jnp


 @jdc.pytree_dataclass
-class MatrixBlock:
+class MatrixBlockRow:
     start_row: jax.Array
-    start_col: jax.Array
-    values: jax.Array
+    """Row indices of the start of each block. Shape should be `(num_blocks,)`."""
+    start_cols: tuple[jax.Array, ...]
+    """Column indices of the start of each block. Shape in tuple should be `(num_blocks,)`."""
+    blocks: tuple[jax.Array, ...]
+    """Blocks of matrix. Shape in tuple should be `(num_blocks, rows, cols)`."""
+
+    def treedef(self) -> Hashable:
+        return tuple(block.shape for block in self.blocks)


 @jdc.pytree_dataclass
-class BlockSparseMatrix:
-    blocks: dict[tuple[int, int], MatrixBlock]
-    """Map from block shape to block (values, start row, start col)."""
+class BlockRowSparseMatrix:
+    block_rows: tuple[MatrixBlockRow, ...]
+    """Batched blocks. Each element in the tuple has a list of consecutive
+    blocks."""
     shape: jdc.Static[tuple[int, int]]
     """Shape of matrix."""

-    def transpose(self) -> BlockSparseMatrix:
-        new_blocks = {}
-        for block_shape, block in self.blocks.items():
-            new_block = MatrixBlock(
-                start_row=block.start_col,
-                start_col=block.start_row,
-                values=jnp.swapaxes(block.values, -1, -2),
-            )
-            new_blocks[block_shape[::-1]] = new_block
-        return BlockSparseMatrix(new_blocks, (self.shape[1], self.shape[0]))
-
     def multiply(self, target: jax.Array) -> jax.Array:
-        result = jnp.zeros(self.shape[0])
-        for block_shape, block in self.blocks.items():
-            start_row, start_col = block.start_row, block.start_col
-            assert len(start_row.shape) == 1
-            assert len(start_col.shape) == 1
-            values = block.values
-            assert values.shape == (len(start_row), *block_shape)
-
-            def multiply_one_block(col, vals) -> jax.Array:
-                target_slice = jax.lax.dynamic_slice_in_dim(
-                    target, col, block_shape[1], axis=0
-                )
-                return jnp.einsum("ij,j->i", vals, target_slice)
-
-            update_indices = start_row[:, None] + jnp.arange(block_shape[0])[None, :]
-            result = result.at[update_indices].add(
-                jax.vmap(multiply_one_block)(start_col, values)
-            )
-        return result
-
-    def todense(self) -> jax.Array:
-        result = jnp.zeros(self.shape)
-        for block_shape, block in self.blocks.items():
-            start_row, start_col = block.start_row, block.start_col
-            assert len(start_row.shape) == 1
-            assert len(start_col.shape) == 1
-            values = block.values
-            assert values.shape == (len(start_row), *block_shape)
-
-            row_indices = start_row[:, None] + jnp.arange(block_shape[0])[None, :]
-            col_indices = start_col[:, None] + jnp.arange(block_shape[1])[None, :]
-            result = result.at[row_indices, col_indices].set(values)
+        assert target.ndim == 1
+
+        def multiply_one_block_row(
+            start_cols: tuple[jax.Array, ...], blocks: tuple[jax.Array, ...]
+        ) -> jax.Array:
+            vecs = list[jax.Array]()
+            for start_col, block in zip(start_cols, blocks):
+                vecs.append(
+                    jnp.einsum(
+                        "ij,j->i",
+                        block,
+                        jax.lax.dynamic_slice_in_dim(target, start_col, block.shape[1]),
+                    )
+                )
+            return jax.tree.reduce(jnp.add, vecs)
+
+        out_slices = []
+        for block_row in self.block_rows:
+            # Do matrix multiplies for all blocks in the block-row.
+            vecs = jax.vmap(multiply_one_block_row)(
+                block_row.start_cols, block_row.blocks
+            )
+
+            proto_block = block_row.blocks[0]
+            assert proto_block.ndim == 3  # (batch, rows, cols)
+            assert vecs.shape == (proto_block.shape[0], proto_block.shape[1])
+            out_slices.append(vecs.flatten())
+            assert block_row.start_row.shape == (vecs.shape[0],)
+
+        result = jnp.concatenate(out_slices, axis=0)
         return result
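
Assuming the module layout above (`jaxls._sparse_matrices`), a small usage sketch of the new class, with invented shapes, checking the scatter-free matvec against its dense equivalent:

    import jax.numpy as jnp

    from jaxls._sparse_matrices import BlockRowSparseMatrix, MatrixBlockRow

    # Two stacked 1x2 blocks: row 0 starts at column 0, row 1 at column 2.
    block_row = MatrixBlockRow(
        start_row=jnp.array([0, 1]),
        start_cols=(jnp.array([0, 2]),),
        blocks=(jnp.array([[[1.0, 2.0]], [[3.0, 4.0]]]),),
    )
    A = BlockRowSparseMatrix(block_rows=(block_row,), shape=(2, 4))

    # Dense equivalent:
    #   [[1, 2, 0, 0],
    #    [0, 0, 3, 4]]
    x = jnp.ones(4)
    print(A.multiply(x))  # [3., 7.]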
