Skip to content

Commit

Permalink
Housekeeping, shift to common infrastructure
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Nov 27, 2021
1 parent 3f190a2 commit 52d8776
Show file tree
Hide file tree
Showing 40 changed files with 228 additions and 505 deletions.
5 changes: 1 addition & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,10 @@ We use Python 3.8 and miniconda for development.
In addition to JAX and the first-party dependencies listed above, note that
this also includes various other helpers:

- **[datargs](https://github.com/brentyi/datargs)** is super useful for
building type-safe argument parsers.
- **[torch](https://github.com/pytorch/pytorch)**'s `Dataset` and
`DataLoader` interfaces are used for training.
- **[fannypack](https://github.com/brentyi/fannypack)** contains some
utilities for downloading datasets, working with PDB, polling repository
commit hashes.
utilities for working with hdf5 files.

The `requirements.txt` provided will install the CPU version of JAX by default.
For CUDA support, please see [instructions](http://github.com/google/jax) from
Expand Down
8 changes: 4 additions & 4 deletions cross_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from typing import Dict, List, Optional, Tuple

import beautifultable
import dcargs
import fifteen
import numpy as onp
import termcolor

from lib import experiment_files, utils

MetricDict = Dict[str, float]


Expand Down Expand Up @@ -130,7 +130,7 @@ def main(args: Args) -> None:
num_folds = 10
for fold in range(num_folds):
# Read evaluation metrics
experiment = experiment_files.ExperimentFiles(
experiment = fifteen.experiments.Experiment(
identifier=f"{experiment_name}/fold_{fold}",
verbose=False,
)
Expand Down Expand Up @@ -164,5 +164,5 @@ def main(args: Args) -> None:


if __name__ == "__main__":
args = utils.parse_args(Args)
args = dcargs.parse(Args)
main(args)
3 changes: 2 additions & 1 deletion data/gen/make_kitti_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import sys

import fannypack
import fifteen
import numpy as onp
from PIL import Image
from tqdm.auto import tqdm
Expand Down Expand Up @@ -123,7 +124,7 @@ def timestep_from_path(path: pathlib.Path) -> int:

if __name__ == "__main__":

fannypack.utils.pdb_safety_net()
fifteen.utils.pdb_safety_net()

path: pathlib.Path
directories = sorted(
Expand Down
14 changes: 8 additions & 6 deletions kitti_transfer_ablation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import dataclasses

import dcargs
import fifteen
import jax_dataclasses
from tqdm.auto import tqdm

from lib import experiment_files, kitti, utils
from lib import kitti


@dataclasses.dataclass
Expand All @@ -17,10 +19,10 @@ class Args:
def main(args: Args) -> None:
for dataset_fold in tqdm(range(10)):
# Experiments to transfer noise models across
ekf_experiment = experiment_files.ExperimentFiles(
ekf_experiment = fifteen.experiments.Experiment(
identifier=args.ekf_experiment_identifier.format(dataset_fold=dataset_fold)
).assert_exists()
fg_experiment = experiment_files.ExperimentFiles(
fg_experiment = fifteen.experiments.Experiment(
identifier=args.fg_experiment_identifier.format(dataset_fold=dataset_fold)
).assert_exists()

Expand Down Expand Up @@ -64,14 +66,14 @@ def main(args: Args) -> None:
)(fg_train_state)

# Write metrics to kitti
experiment_files.ExperimentFiles(
fifteen.experiments.Experiment(
identifier=f"kitti/ekf/hetero/trained_on_fg/fold_{dataset_fold}"
).clear().write_metadata("best_val_metrics", ekf_metrics)
experiment_files.ExperimentFiles(
fifteen.experiments.Experiment(
identifier=f"kitti/fg/hetero/trained_on_ekf/fold_{dataset_fold}"
).clear().write_metadata("best_val_metrics", fg_metrics)


if __name__ == "__main__":
args = utils.parse_args(Args)
args = dcargs.parse(Args)
main(args)
80 changes: 0 additions & 80 deletions lib/array_struct.py

This file was deleted.

46 changes: 30 additions & 16 deletions lib/disk/data.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,45 @@
"""Structures (pytrees) for working with disk data."""

from typing import cast

import jax
import jax_dataclasses
import numpy as onp
from jax import numpy as jnp
from typing_extensions import Annotated

from .. import array_struct
null_array = cast(jnp.ndarray, None)
# ^Placeholder value to be used as a dataclass field default, to enable structs that
# contain only a partial set of values.
#
# An intuitive solution is to populate fields with a dummy default array like
# `jnp.empty(shape=(0,))`, but this can cause silent broadcasting/tracing issues.
#
# So instead we use `None` as the default value. Which is nice because it leads to loud
# runtime errors when uninitialized values are accidentally used.
#
# Note that the correct move would be to hint fields as `Optional[jnp.ndarray]`, but
# this would result in code that's littered with `assert __ is not None` statements
# and/or casts. Which is annoying. So instead we just pretend `None` is an array.


@jax_dataclasses.pytree_dataclass
class _DiskStruct(array_struct.ShapeAnnotatedStruct):
class _DiskStruct(jax_dataclasses.EnforcedAnnotationsMixin):
"""Values in our dataset."""

image: Annotated[jnp.ndarray, (120, 120, 3)] = array_struct.null_array
visible_pixels_count: Annotated[jnp.ndarray, ()] = array_struct.null_array
position: Annotated[jnp.ndarray, (2,)] = array_struct.null_array
velocity: Annotated[jnp.ndarray, (2,)] = array_struct.null_array
image: Annotated[jnp.ndarray, (120, 120, 3), jnp.floating] = null_array
visible_pixels_count: Annotated[jnp.ndarray, (), jnp.floating] = null_array
position: Annotated[jnp.ndarray, (2,), jnp.floating] = null_array
velocity: Annotated[jnp.ndarray, (2,), jnp.floating] = null_array


_DATASET_MEANS = _DiskStruct(
_DATASET_MEANS = dict(
image=onp.array([24.30598765, 29.76503314, 29.86749727], dtype=onp.float32), # type: ignore
position=onp.array([-0.08499543, 0.07917813], dtype=onp.float32), # type: ignore
velocity=onp.array([0.02876372, 0.06096543], dtype=onp.float32), # type: ignore
visible_pixels_count=104.87143, # type: ignore
)
_DATASET_STD_DEVS = _DiskStruct(
_DATASET_STD_DEVS = dict(
image=onp.array([74.88154621, 81.87872827, 82.00088091], dtype=onp.float32), # type: ignore
position=onp.array([30.53421, 30.84835], dtype=onp.float32), # type: ignore
velocity=onp.array([6.636913, 6.647381], dtype=onp.float32), # type: ignore
Expand All @@ -39,8 +53,8 @@ def normalize(self, scale_only: bool = False) -> "DiskStructNormalized":
"""Normalize contents."""

def _norm(value, mean, std):
if value is array_struct.null_array:
return array_struct.null_array
if value is null_array:
return null_array

if scale_only:
return value / std
Expand All @@ -51,8 +65,8 @@ def _norm(value, mean, std):
**jax.tree_map(
_norm,
vars(self),
vars(_DATASET_MEANS),
vars(_DATASET_STD_DEVS),
_DATASET_MEANS,
_DATASET_STD_DEVS,
)
)

Expand All @@ -63,8 +77,8 @@ def unnormalize(self, scale_only: bool = False) -> "DiskStructRaw":
"""Unnormalize contents."""

def _unnorm(value, mean, std):
if value is array_struct.null_array:
return array_struct.null_array
if value is null_array:
return null_array

if scale_only:
return value * std
Expand All @@ -75,7 +89,7 @@ def _unnorm(value, mean, std):
**jax.tree_map(
_unnorm,
vars(self),
vars(_DATASET_MEANS),
vars(_DATASET_STD_DEVS),
_DATASET_MEANS,
_DATASET_STD_DEVS,
)
)
6 changes: 5 additions & 1 deletion lib/disk/data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@

import fannypack
import jax
import torch

# For future projects, we probably want to use fifteen.data.DataLoader instead of the
# torch DataLoader, but we are keeping the torch one because that's what was used for
# the paper results.
import torch.utils.data

from .. import utils
from . import data, experiment_config
Expand Down
6 changes: 6 additions & 0 deletions lib/disk/experiment_config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
"""Experiment configurations for disk task.

Note: we'd structure these very differently if we were to rewrite this code,
particularly to replace inheritance with nested dataclasses for common fields. (The
latter is now supported in `dcargs`.)"""

import dataclasses
import enum
from typing import Literal
Expand Down
2 changes: 1 addition & 1 deletion lib/disk/networks_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class DiskLstm(nn.Module):

@nn.compact
def __call__(self, inputs: data.DiskStructNormalized) -> jnp.ndarray:
N, T = inputs.check_shapes_and_get_batch_axes()
N, T = inputs.get_batch_axes()
images = inputs.image
assert images.shape == (N, T, 120, 120, 3)

Expand Down
15 changes: 8 additions & 7 deletions lib/disk/training_ekf.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Any, Optional, Tuple

import fifteen
import jax
import jax_dataclasses
import optax
from jax import numpy as jnp

from .. import experiment_files, manifold_ekf, utils
from .. import manifold_ekf, utils
from . import data, experiment_config, fg_system, networks

Pytree = Any
Expand Down Expand Up @@ -42,7 +43,7 @@ def initialize(
) -> "TrainState":
# Load position CNN
cnn_model, cnn_params = networks.make_position_cnn(seed=config.random_seed)
cnn_params = experiment_files.ExperimentFiles(
cnn_params = fifteen.experiments.Experiment(
identifier=config.pretrained_virtual_sensor_identifier.format(
dataset_fold=config.dataset_fold
)
Expand Down Expand Up @@ -88,17 +89,17 @@ def initialize(
@jax.jit
def training_step(
self, batch: data.DiskStructNormalized
) -> Tuple["TrainState", experiment_files.TensorboardLogData]:
) -> Tuple["TrainState", fifteen.experiments.TensorboardLogData]:

# Shape checks
(batch_size, sequence_length) = batch.check_shapes_and_get_batch_axes()
(batch_size, sequence_length) = batch.get_batch_axes()
assert sequence_length == self.config.train_sequence_length

def compute_loss_single(
trajectory: data.DiskStructNormalized,
learnable_params: Pytree,
) -> jnp.ndarray:
(timesteps,) = trajectory.check_shapes_and_get_batch_axes()
(timesteps,) = trajectory.get_batch_axes()

unnormed_position = (
data.DiskStructNormalized(position=trajectory.position)
Expand Down Expand Up @@ -164,7 +165,7 @@ def compute_loss(
)

# Log data
log_data = experiment_files.TensorboardLogData(
log_data = fifteen.experiments.TensorboardLogData(
scalars={
"train/training_loss": loss,
"train/gradient_norm": optax.global_norm(grads),
Expand All @@ -184,7 +185,7 @@ def run_ekf(
trajectory: data.DiskStructNormalized,
learnable_params: Optional[Pytree] = None,
) -> manifold_ekf.MultivariateGaussian[fg_system.State]:
(timesteps,) = trajectory.check_shapes_and_get_batch_axes()
(timesteps,) = trajectory.get_batch_axes()

# Some type aliases
Belief = manifold_ekf.MultivariateGaussian[fg_system.State]
Expand Down
Loading

0 comments on commit 52d8776

Please sign in to comment.