Skip to content

Commit

Permalink
Versioning improvements, fix some runtime errors
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Jan 4, 2022
1 parent 52d8776 commit 6044678
Show file tree
Hide file tree
Showing 34 changed files with 124 additions and 91 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,17 @@ Code release for our IROS 2021 conference paper:
<sup>2</sup><em>Max Planck Institute for Intelligent Systems,
`[email protected]`</em>


**Bibtex:**
```
@inproceedings{yi2021iros,
author={Brent Yi and Michelle Lee and Alina Kloss and Roberto Mart\'in-Mart\'in and Jeannette Bohg},
title = {Differentiable Factor Graph Optimization for Learning Smoothers},
year = 2021,
BOOKTITLE = {2021 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)}
}
```

---

This repository contains models, training scripts, and experimental results, and
Expand Down
3 changes: 2 additions & 1 deletion cross_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ def main(args: Args) -> None:
for fold in range(num_folds):
# Read evaluation metrics
experiment = fifteen.experiments.Experiment(
identifier=f"{experiment_name}/fold_{fold}",
data_dir=pathlib.Path("./experiments/")
/ f"{experiment_name}/fold_{fold}",
verbose=False,
)
try:
Expand Down
5 changes: 3 additions & 2 deletions data/gen/make_kitti_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import fannypack
import fifteen
import jax_dataclasses
import numpy as onp
from PIL import Image
from tqdm.auto import tqdm
Expand Down Expand Up @@ -157,15 +158,15 @@ def timestep_from_path(path: pathlib.Path) -> int:
traj_file.resize(2)

# Load data from first camera
traj_file[0] = vars(
traj_file[0] = jax_dataclasses.asdict(
load_data(
pose_txt=directory.parent / f"{dataset_id}_image1.txt",
image_dir=directory / "image_2",
)
)

# Load data from second camera
traj_file[1] = vars(
traj_file[1] = jax_dataclasses.asdict(
load_data(
pose_txt=directory.parent / f"{dataset_id}_image2.txt",
image_dir=directory / "image_3",
Expand Down
13 changes: 9 additions & 4 deletions kitti_transfer_ablation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import dataclasses
import pathlib

import dcargs
import fifteen
Expand All @@ -20,10 +21,12 @@ def main(args: Args) -> None:
for dataset_fold in tqdm(range(10)):
# Experiments to transfer noise models across
ekf_experiment = fifteen.experiments.Experiment(
identifier=args.ekf_experiment_identifier.format(dataset_fold=dataset_fold)
data_dir=pathlib.Path("./experiments/")
/ args.ekf_experiment_identifier.format(dataset_fold=dataset_fold)
).assert_exists()
fg_experiment = fifteen.experiments.Experiment(
identifier=args.fg_experiment_identifier.format(dataset_fold=dataset_fold)
data_dir=pathlib.Path("./experiments/")
/ args.fg_experiment_identifier.format(dataset_fold=dataset_fold)
).assert_exists()

# Read experiment configurations for each experiment
Expand Down Expand Up @@ -67,10 +70,12 @@ def main(args: Args) -> None:

# Write metrics to kitti
fifteen.experiments.Experiment(
identifier=f"kitti/ekf/hetero/trained_on_fg/fold_{dataset_fold}"
data_dir=pathlib.Path("./experiments/")
/ f"kitti/ekf/hetero/trained_on_fg/fold_{dataset_fold}"
).clear().write_metadata("best_val_metrics", ekf_metrics)
fifteen.experiments.Experiment(
identifier=f"kitti/fg/hetero/trained_on_ekf/fold_{dataset_fold}"
data_dir=pathlib.Path("./experiments/")
/ f"kitti/fg/hetero/trained_on_ekf/fold_{dataset_fold}"
).clear().write_metadata("best_val_metrics", fg_metrics)


Expand Down
4 changes: 2 additions & 2 deletions lib/disk/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _norm(value, mean, std):
return DiskStructNormalized(
**jax.tree_map(
_norm,
vars(self),
jax_dataclasses.asdict(self),
_DATASET_MEANS,
_DATASET_STD_DEVS,
)
Expand All @@ -88,7 +88,7 @@ def _unnorm(value, mean, std):
return DiskStructRaw(
**jax.tree_map(
_unnorm,
vars(self),
jax_dataclasses.asdict(self),
_DATASET_MEANS,
_DATASET_STD_DEVS,
)
Expand Down
3 changes: 2 additions & 1 deletion lib/disk/data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import fannypack
import jax
import numpy as onp

# For future projects, we probably want to use fifteen.data.DataLoader instead of the
# torch DataLoader, but keeping the torch one because that's what was used for the paper
Expand Down Expand Up @@ -93,7 +94,7 @@ def load_trajectories(

traj = data.DiskStructRaw(
**{
field.name: trajectory[field.name]
field.name: trajectory[field.name].astype(onp.float32)
for field in dataclasses.fields(data.DiskStructRaw)
}
).normalize()
Expand Down
8 changes: 4 additions & 4 deletions lib/disk/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,14 @@ def make(units: int, layers: int, output_dim: int):
return MLP(units=units, layers=layers, output_dim=output_dim)

@nn.compact
def __call__(self, inputs: jnp.ndarray):
def __call__(self, inputs: jnp.ndarray): # type: ignore
x = inputs

for i in range(self.layers):
x = nn.Dense(self.units, kernel_init=relu_layer_init)(x)
x = nn.Dense(features=self.units, kernel_init=relu_layer_init)(x)
x = nn.relu(x)

x = nn.Dense(self.output_dim, kernel_init=linear_layer_init)(x)
x = nn.Dense(features=self.output_dim, kernel_init=linear_layer_init)(x)
return x


Expand All @@ -96,7 +96,7 @@ class DiskVirtualSensor(nn.Module):
output_dim: int = 2

@nn.compact
def __call__(self, inputs: jnp.ndarray):
def __call__(self, inputs: jnp.ndarray): # type: ignore
x = inputs
N = x.shape[0]
assert x.shape == (N, 120, 120, 3), x.shape
Expand Down
14 changes: 7 additions & 7 deletions lib/disk/networks_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ class DiskLstm(nn.Module):
bidirectional: bool

@nn.compact
def __call__(self, inputs: data.DiskStructNormalized) -> jnp.ndarray:
def __call__(self, inputs: data.DiskStructNormalized) -> jnp.ndarray: # type: ignore
N, T = inputs.get_batch_axes()
images = inputs.image
assert images.shape == (N, T, 120, 120, 3)

# Initial carry by encoding ground-truth initial state
initial_carry = nn.Dense(32, kernel_init=networks.relu_layer_init)(
initial_carry = nn.Dense(features=32, kernel_init=networks.relu_layer_init)(
jnp.concatenate(
[
inputs.position[:, 0, :],
Expand All @@ -34,9 +34,9 @@ def __call__(self, inputs: data.DiskStructNormalized) -> jnp.ndarray:
)
assert initial_carry.shape == (N, 32)
initial_carry = nn.relu(initial_carry)
initial_carry = nn.Dense(32 * 2, kernel_init=networks.linear_layer_init)(
initial_carry
)
initial_carry = nn.Dense(
features=32 * 2, kernel_init=networks.linear_layer_init
)(initial_carry)
initial_carry = (initial_carry[..., :32], initial_carry[..., 32:])

# Image encoder
Expand All @@ -54,9 +54,9 @@ def __call__(self, inputs: data.DiskStructNormalized) -> jnp.ndarray:
assert x.shape == (N, T, 32)

# Output
x = nn.Dense(32, kernel_init=networks.relu_layer_init)(x)
x = nn.Dense(features=32, kernel_init=networks.relu_layer_init)(x)
x = nn.relu(x)
x = nn.Dense(2, kernel_init=networks.linear_layer_init)(x)
x = nn.Dense(features=2, kernel_init=networks.linear_layer_init)(x)
assert x.shape == (N, T, 2)

return x
5 changes: 3 additions & 2 deletions lib/disk/training_ekf.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pathlib
from typing import Any, Optional, Tuple

import fifteen
Expand All @@ -10,7 +11,6 @@
from . import data, experiment_config, fg_system, networks

Pytree = Any
PRNGKey = jnp.ndarray


DiskEkf = manifold_ekf.EkfDefinition[fg_system.State, jnp.ndarray, None]
Expand Down Expand Up @@ -44,7 +44,8 @@ def initialize(
# Load position CNN
cnn_model, cnn_params = networks.make_position_cnn(seed=config.random_seed)
cnn_params = fifteen.experiments.Experiment(
identifier=config.pretrained_virtual_sensor_identifier.format(
data_dir=pathlib.Path("./experiments/")
/ config.pretrained_virtual_sensor_identifier.format(
dataset_fold=config.dataset_fold
)
).restore_checkpoint(cnn_params, prefix="best_val_params_")
Expand Down
5 changes: 3 additions & 2 deletions lib/disk/training_fg.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pathlib
from typing import Any, Optional, Tuple

import fifteen
Expand All @@ -11,7 +12,6 @@
from . import data, experiment_config, fg_system, fg_utils, networks

Pytree = Any
PRNGKey = jnp.ndarray


@jax_dataclasses.pytree_dataclass
Expand Down Expand Up @@ -43,7 +43,8 @@ def initialize(
# Load position CNN
cnn_model, cnn_params = networks.make_position_cnn(seed=config.random_seed)
cnn_params = fifteen.experiments.Experiment(
identifier=config.pretrained_virtual_sensor_identifier.format(
data_dir=pathlib.Path("./experiments/")
/ config.pretrained_virtual_sensor_identifier.format(
dataset_fold=config.dataset_fold
)
).restore_checkpoint(cnn_params, prefix="best_val_params_")
Expand Down
1 change: 0 additions & 1 deletion lib/disk/training_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from . import data, experiment_config, networks_lstm

Pytree = Any
PRNGKey = jnp.ndarray


@jax_dataclasses.pytree_dataclass
Expand Down
1 change: 0 additions & 1 deletion lib/disk/training_virtual_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from . import data, experiment_config, networks

Pytree = Any
PRNGKey = jnp.ndarray


@jax_dataclasses.pytree_dataclass
Expand Down
4 changes: 2 additions & 2 deletions lib/kitti/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def _unnorm(value, mean, std):
return KittiStructRaw(
**jax.tree_map(
_unnorm,
vars(self),
jax_dataclasses.asdict(self),
_DATASET_MEANS,
_DATASET_STD_DEVS,
)
Expand Down Expand Up @@ -156,7 +156,7 @@ def _norm(value, mean, std):
return KittiStructNormalized(
**jax.tree_map(
_norm,
vars(self),
jax_dataclasses.asdict(self),
_DATASET_MEANS,
_DATASET_STD_DEVS,
)
Expand Down
6 changes: 2 additions & 4 deletions lib/kitti/fg_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@

from . import data, experiment_config, fg_system, fg_utils

PRNGKey = jnp.ndarray


def compute_loss(
graph: jaxfg.core.StackedFactorGraph,
Expand All @@ -20,7 +18,7 @@ def compute_loss(
experiment_config.JointNllLossConfig,
experiment_config.SurrogateLossConfig,
],
prng_key: jnp.ndarray,
prng_key: jax.random.KeyArray,
) -> jnp.ndarray:
"""Given an updated factor graph, ground-truth trajectory, and loss config, compute
a single-trajectory loss."""
Expand All @@ -37,7 +35,7 @@ def _compute_surrogate_loss(
graph: jaxfg.core.StackedFactorGraph,
trajectory_raw: data.KittiStructRaw,
loss_config: experiment_config.SurrogateLossConfig,
prng_key: PRNGKey,
prng_key: jax.random.KeyArray,
) -> jnp.ndarray:
"""Compute an end-to-end loss."""
timesteps: int = len(tuple(graph.get_variables()))
Expand Down
21 changes: 11 additions & 10 deletions lib/kitti/networks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Neural network definitions, helpers, and related types."""

import pathlib
from typing import Any, NamedTuple, Protocol, Tuple

import fifteen
Expand All @@ -14,7 +15,6 @@
LearnableParams = Pytree
KittiVirtualSensorParameters = Pytree
StackedImages = jnp.ndarray
PRNGKey = jnp.ndarray

relu_layer_init = nn.initializers.kaiming_normal() # variance = 2.0 / fan_in
linear_layer_init = nn.initializers.lecun_normal() # variance = 1.0 / fan_in
Expand All @@ -40,7 +40,7 @@ def __call__(
self,
learnable_params: Pytree,
stacked_images: jnp.ndarray,
prng_key: jnp.ndarray,
prng_key: jax.random.KeyArray,
train: bool,
) -> RegressedUncertainties:
...
Expand All @@ -56,7 +56,7 @@ class KittiVirtualSensor(nn.Module):
output_dim: int = 4

@nn.compact
def __call__(self, inputs: jnp.ndarray, train: bool) -> jnp.ndarray:
def __call__(self, inputs: jnp.ndarray, train: bool) -> jnp.ndarray: # type: ignore
x = inputs
N = x.shape[0]
assert x.shape == (N, 50, 150, 6), x.shape
Expand Down Expand Up @@ -113,15 +113,15 @@ def __call__(self, inputs: jnp.ndarray, train: bool) -> jnp.ndarray:
x = x.reshape((N, -1)) # type: ignore

# fc1
x = nn.Dense(128, kernel_init=relu_layer_init)(x)
x = nn.Dense(features=128, kernel_init=relu_layer_init)(x)
x = nn.relu(x)

# fc2
x = nn.Dense(128, kernel_init=relu_layer_init)(x)
x = nn.Dense(features=128, kernel_init=relu_layer_init)(x)
x = nn.relu(x)

# fc3
x = nn.Dense(self.output_dim, kernel_init=linear_layer_init)(x)
x = nn.Dense(features=self.output_dim, kernel_init=linear_layer_init)(x)

assert x.shape == (N, self.output_dim)
return x
Expand Down Expand Up @@ -151,7 +151,9 @@ def load_pretrained_observation_cnn(
# Note that seed does not matter, because parameters will be read from checkpoint
model, params = make_observation_cnn(random_seed=0)

experiment = fifteen.experiments.Experiment(identifier=experiment_identifier)
experiment = fifteen.experiments.Experiment(
data_dir=pathlib.Path("./experiments/") / experiment_identifier
)
params = experiment.restore_checkpoint(params, prefix="best_val_params_")

return model, params
Expand All @@ -168,7 +170,6 @@ def make_regress_velocities(

def regress_velocities(
stacked_images: jnp.ndarray,
# prng_key: PRNGKey,
) -> RegressedVelocities:
N = stacked_images.shape[0]
assert stacked_images.shape == (N, 50, 150, 6)
Expand Down Expand Up @@ -259,7 +260,7 @@ def make_regress_uncertainties(
def regress_uncertainties(
learnable_params: ConstantUncertaintyParams,
stacked_images: jnp.ndarray,
prng_key: jnp.ndarray,
prng_key: jax.random.KeyArray,
train: bool,
) -> RegressedUncertainties:
sequence_length = stacked_images.shape[0]
Expand All @@ -286,7 +287,7 @@ def regress_uncertainties(
def regress_uncertainties(
learnable_params: HeteroscedasticUncertaintyParams,
stacked_images: jnp.ndarray,
prng_key: PRNGKey,
prng_key: jax.random.KeyArray,
train: bool,
) -> RegressedUncertainties:
sequence_length = stacked_images.shape[0]
Expand Down
Loading

0 comments on commit 6044678

Please sign in to comment.