Skip to content

Commit

Permalink
Pretty print
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Jul 4, 2024
1 parent ed9d014 commit 186f979
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 75 deletions.
16 changes: 16 additions & 0 deletions jaxfg2/_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,22 @@ def get_value[T](self, var: Var[T]) -> T:
index = jnp.searchsorted(self.ids_from_type[var_type], var.id)
return jax.tree.map(lambda x: x[index], self.vals_from_type[var_type])

def __repr__(self) -> str:
out_lines = list[str]()

for var_type, ids in self.ids_from_type.items():
for i in range(ids.shape[-1]):
batch_axes = ids.shape[:-1]
val = jax.tree_map(
lambda x: x.take(indices=i, axis=len(batch_axes)),
self.vals_from_type[var_type],
)
out_lines.append(
f" {var_type.__name__}(" + f"{ids[..., i]}): ".ljust(8) + f"{val},"
)

return f"VarValues(\n{'\n'.join(out_lines)}\n)"

def get_stacked_value[T](self, var_type: type[Var[T]]) -> T:
"""Get the value of all variables of a specific type."""
return self.vals_from_type[var_type]
Expand Down
87 changes: 12 additions & 75 deletions scripts/pose_graph_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,95 +23,32 @@

# Create factors: each defines a conditional probability distribution over some
# variables.


def prior_two(vals, var, orig, var1, orig1):
return jnp.concatenate(
[(vals[var] @ orig.inverse()).log(), (vals[var1] @ orig1.inverse()).log()]
)


def prior_one(vals, var, orig):
return (vals[var] @ orig.inverse()).log()


# def between(vals, var0, var1, delta):
# return ((vals[var0] @ delta).inverse() @ vals[var1]).log()


factors = [
# Prior factor for pose 0.
jaxfg2.Factor.make(
# lambda vals, var0, var, init: (
# vals[var] @ init
# ).log(),
prior_two,
(
pose_variables[0],
jaxlie.SE2.from_translation(jnp.array([100.0, 10.0])),
pose_variables[1],
jaxlie.SE2.from_translation(jnp.array([200.0, 20.0])),
),
),
jaxfg2.Factor.make(
# lambda vals, var0, var, init: (
# vals[var] @ init
# ).log(),
prior_two,
(
pose_variables[0],
jaxlie.SE2.from_translation(jnp.array([100.0, 10.0])),
pose_variables[1],
jaxlie.SE2.from_translation(jnp.array([200.0, 20.0])),
),
lambda vals, var, init: (vals[var] @ init.inverse()).log(),
(pose_variables[0], jaxlie.SE2.from_xy_theta(0.0, 0.0, 0.0)),
),
# Prior factor for pose 1.
jaxfg2.Factor.make(
# lambda vals, var0, var, init: (
# vals[var] @ init
# ).log(),
prior_one,
(
pose_variables[0],
jaxlie.SE2.from_translation(jnp.array([100.0, 10.0])),
),
lambda vals, var, init: (vals[var] @ init.inverse()).log(),
(pose_variables[1], jaxlie.SE2.from_xy_theta(2.0, 0.0, 0.0)),
),
# "Between" factor.
jaxfg2.Factor.make(
# lambda vals, var0, var, init: (
# vals[var] @ init
# ).log(),
prior_one,
(
pose_variables[1],
jaxlie.SE2.from_translation(jnp.array([200.0, 20.0])),
),
lambda vals, var0, var1, delta: (
(vals[var0].inverse() @ vals[var1]) @ delta.inverse()
).log(),
(pose_variables[0], pose_variables[1], jaxlie.SE2.from_xy_theta(1.0, 0.0, 0.0)),
),
# jaxfg2.Factor.make(
# # lambda vals, var: (
# # vals[var] @ jaxlie.SE2.from_translation(jnp.array([200.0, 20.0]))
# # ).log(),
# # lambda *args: prior(*args),
# prior,
# (pose_variables[1], jaxlie.SE2.from_translation(jnp.array([200.0, 20.0]))),
# ),
# jaxfg2.Factor.make(
# # lambda vals, var: (
# # vals[var] @ jaxlie.SE2.from_translation(jnp.array([200.0, 20.0]))
# # ).log(),
# # lambda *args: prior(*args),
# between,
# (
# pose_variables[0],
# pose_variables[1],
# jaxlie.SE2.from_translation(jnp.array([50.0, 10.0])),
# ),
# ),
]

# Create our "stacked" factor graph. (this is the only kind of factor graph)
#
# This goes through factors, and preprocesses them to enable vectorization of
# computations. If we have 1000 PriorFactor objects, we stack all of the associated
# values and perform a batched operation that computes all 1000 residuals.
graph = jaxfg2.StackedFactorGraph.make(factors, vars=pose_variables)
graph = jaxfg2.StackedFactorGraph.make(factors, pose_variables)


# Create an assignments object, which you can think of as a (variable => value) mapping.
Expand Down

0 comments on commit 186f979

Please sign in to comment.