diff --git a/examples/pose_graph_simple.py b/examples/pose_graph_simple.py index b0829c9..d53c431 100644 --- a/examples/pose_graph_simple.py +++ b/examples/pose_graph_simple.py @@ -33,10 +33,10 @@ ), # "Between" factor. jaxls.Factor( - lambda vals, var0, var1, delta: ( + lambda vals, delta, var0, var1: ( (vals[var0].inverse() @ vals[var1]) @ delta.inverse() ).log(), - (vars[0], vars[1], jaxlie.SE2.from_xy_theta(1.0, 0.0, 0.0)), + (jaxlie.SE2.from_xy_theta(1.0, 0.0, 0.0), vars[0], vars[1]), ), ] diff --git a/src/jaxls/_factor_graph.py b/src/jaxls/_factor_graph.py index 16b2b11..000f789 100644 --- a/src/jaxls/_factor_graph.py +++ b/src/jaxls/_factor_graph.py @@ -398,17 +398,22 @@ def make[*Args_]( return Factor(compute_residual, args, jac_mode) def _get_batch_axes(self) -> tuple[int, ...]: - def traverse_args(current: Any) -> tuple[int, ...]: + def traverse_args(current: Any) -> tuple[int, ...] | None: children_and_meta = default_registry.flatten_one_level(current) - assert children_and_meta is not None + if children_and_meta is None: + return None for child in children_and_meta[0]: if isinstance(child, Var): return () if isinstance(child.id, int) else child.id.shape else: - return traverse_args(child) - assert False, "No variables found in factor!" - - return traverse_args(self.args) + batch_axes = traverse_args(child) + if batch_axes is not None: + return batch_axes + return None + + batch_axes = traverse_args(self.args) + assert batch_axes is not None, "No variables found in factor!" + return batch_axes @jdc.pytree_dataclass