From 52d8776fe772d7483ff0efadddf07196a8bc8c12 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Sat, 27 Nov 2021 11:50:56 -0800 Subject: [PATCH] Housekeeping, shift to common infrastructure --- README.md | 5 +- cross_validate.py | 8 +- data/gen/make_kitti_dataset.py | 3 +- kitti_transfer_ablation.py | 14 +- lib/array_struct.py | 80 ---------- lib/disk/data.py | 46 ++++-- lib/disk/data_loading.py | 6 +- lib/disk/experiment_config.py | 6 + lib/disk/networks_lstm.py | 2 +- lib/disk/training_ekf.py | 15 +- lib/disk/training_fg.py | 23 +-- lib/disk/training_lstm.py | 9 +- lib/disk/training_virtual_sensor.py | 8 +- lib/disk/validation_fg.py | 2 +- lib/experiment_files.py | 222 --------------------------- lib/kitti/data.py | 54 ++++--- lib/kitti/data_loading.py | 6 +- lib/kitti/experiment_config.py | 12 +- lib/kitti/networks.py | 4 +- lib/kitti/networks_lstm.py | 2 +- lib/kitti/training_ekf.py | 11 +- lib/kitti/training_fg.py | 9 +- lib/kitti/training_lstm.py | 15 +- lib/kitti/training_virtual_sensor.py | 6 +- lib/kitti/validation_ekf.py | 4 +- lib/kitti/validation_fg.py | 4 +- lib/kitti/validation_lstm.py | 2 +- lib/lstm_layers.py | 3 +- lib/train_state_protocol.py | 7 +- lib/utils.py | 44 +----- lib/validation_tracker.py | 6 +- requirements.txt | 5 +- train_disk_ekf.py | 11 +- train_disk_fg.py | 11 +- train_disk_lstm.py | 11 +- train_disk_virtual_sensor.py | 11 +- train_kitti_ekf.py | 10 +- train_kitti_fg.py | 11 +- train_kitti_lstm.py | 13 +- train_kitti_virtual_sensor.py | 12 +- 40 files changed, 228 insertions(+), 505 deletions(-) delete mode 100644 lib/array_struct.py delete mode 100644 lib/experiment_files.py diff --git a/README.md b/README.md index 904d164..3580c67 100644 --- a/README.md +++ b/README.md @@ -121,13 +121,10 @@ We use Python 3.8 and miniconda for development. In addition to JAX and the first-party dependencies listed above, note that this also includes various other helpers: - - **[datargs](https://github.com/brentyi/datargs)** is super useful for - building type-safe argument parsers. - **[torch](https://github.com/pytorch/pytorch)**'s `Dataset` and `DataLoader` interfaces are used for training. - **[fannypack](https://github.com/brentyi/fannypack)** contains some - utilities for downloading datasets, working with PDB, polling repository - commit hashes. + utilities for working with hdf5 files. The `requirements.txt` provided will install the CPU version of JAX by default. 
For CUDA support, please see [instructions](http://github.com/google/jax) from diff --git a/cross_validate.py b/cross_validate.py index f9b0965..123f398 100644 --- a/cross_validate.py +++ b/cross_validate.py @@ -5,11 +5,11 @@ from typing import Dict, List, Optional, Tuple import beautifultable +import dcargs +import fifteen import numpy as onp import termcolor -from lib import experiment_files, utils - MetricDict = Dict[str, float] @@ -130,7 +130,7 @@ def main(args: Args) -> None: num_folds = 10 for fold in range(num_folds): # Read evaluation metrics - experiment = experiment_files.ExperimentFiles( + experiment = fifteen.experiments.Experiment( identifier=f"{experiment_name}/fold_{fold}", verbose=False, ) @@ -164,5 +164,5 @@ def main(args: Args) -> None: if __name__ == "__main__": - args = utils.parse_args(Args) + args = dcargs.parse(Args) main(args) diff --git a/data/gen/make_kitti_dataset.py b/data/gen/make_kitti_dataset.py index dcf6c7d..c84ba76 100644 --- a/data/gen/make_kitti_dataset.py +++ b/data/gen/make_kitti_dataset.py @@ -17,6 +17,7 @@ import sys import fannypack +import fifteen import numpy as onp from PIL import Image from tqdm.auto import tqdm @@ -123,7 +124,7 @@ def timestep_from_path(path: pathlib.Path) -> int: if __name__ == "__main__": - fannypack.utils.pdb_safety_net() + fifteen.utils.pdb_safety_net() path: pathlib.Path directories = sorted( diff --git a/kitti_transfer_ablation.py b/kitti_transfer_ablation.py index f24b06c..7aa4d5c 100644 --- a/kitti_transfer_ablation.py +++ b/kitti_transfer_ablation.py @@ -1,9 +1,11 @@ import dataclasses +import dcargs +import fifteen import jax_dataclasses from tqdm.auto import tqdm -from lib import experiment_files, kitti, utils +from lib import kitti @dataclasses.dataclass @@ -17,10 +19,10 @@ class Args: def main(args: Args) -> None: for dataset_fold in tqdm(range(10)): # Experiments to transfer noise models across - ekf_experiment = experiment_files.ExperimentFiles( + ekf_experiment = fifteen.experiments.Experiment( identifier=args.ekf_experiment_identifier.format(dataset_fold=dataset_fold) ).assert_exists() - fg_experiment = experiment_files.ExperimentFiles( + fg_experiment = fifteen.experiments.Experiment( identifier=args.fg_experiment_identifier.format(dataset_fold=dataset_fold) ).assert_exists() @@ -64,14 +66,14 @@ def main(args: Args) -> None: )(fg_train_state) # Write metrics to kitti - experiment_files.ExperimentFiles( + fifteen.experiments.Experiment( identifier=f"kitti/ekf/hetero/trained_on_fg/fold_{dataset_fold}" ).clear().write_metadata("best_val_metrics", ekf_metrics) - experiment_files.ExperimentFiles( + fifteen.experiments.Experiment( identifier=f"kitti/fg/hetero/trained_on_ekf/fold_{dataset_fold}" ).clear().write_metadata("best_val_metrics", fg_metrics) if __name__ == "__main__": - args = utils.parse_args(Args) + args = dcargs.parse(Args) main(args) diff --git a/lib/array_struct.py b/lib/array_struct.py deleted file mode 100644 index e357282..0000000 --- a/lib/array_struct.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Helper for constructing array structures with shape checks at runtime.""" - - -import dataclasses -from typing import Optional, Tuple, cast - -from jax import numpy as jnp -from typing_extensions import get_type_hints - -null_array = cast(jnp.ndarray, None) -# ^Placeholder value to be used as a dataclass field default, to enable structs that -# contain only a partial set of values. 
-# -# An intuitive solution is to populate fields with a dummy default array like -# `jnp.empty(shape=(0,))`), but this can cause silent broadcasting/tracing issues. -# -# So instead we use `None` as the default value. Which is nice because it leads to loud -# runtime errors when uninitialized values are accidentally used. -# -# Note that the correct move would be to hint fields as `Optional[jnp.ndarray]`, but -# this would result in code that's littered with `assert __ is not None` statements -# and/or casts. Which is annoying. So instead we just pretend `None` is an array, - - -class ShapeAnnotatedStruct: - """Base class for dataclasses whose fields are annotated with expected shapes. Helps - with assertions + checking batch axes. - - Example of an annotated field: - - array: Annotated[jnp.ndarray, (50, 150, 3)] - - """ - - def __getattribute__(self, name): - out = super().__getattribute__(name) - assert out is not None - return out - - def check_shapes_and_get_batch_axes(self) -> Tuple[int, ...]: - """Make sure shapes of arrays are consistent with annotations, then return any - leading batch axes (which should be shared across all contained arrays).""" - - assert dataclasses.is_dataclass(self) - - annotations = get_type_hints(type(self), include_extras=True) - batch_axes: Optional[Tuple[int, ...]] = None - - # For each field... - for field in dataclasses.fields(self): - value = self.__getattribute__(field.name) - if value is null_array: - # Don't do anything for placeholder objects - continue - - # Get expected shape, sans batch axes - expected_shape = annotations[field.name].__metadata__[0] - assert isinstance(expected_shape, tuple) - - # Get actual shape - shape: Tuple[int, ...] - if isinstance(value, float): - shape = () - else: - assert hasattr(value, "shape") - shape = value.shape - - # Actual shape should be expected shape prefixed by some batch axes - if len(expected_shape) > 0: - assert shape[-len(expected_shape) :] == expected_shape - field_batch_axes = shape[: -len(expected_shape)] - else: - field_batch_axes = shape - - if batch_axes is None: - batch_axes = field_batch_axes - assert batch_axes == field_batch_axes - - assert batch_axes is not None - return batch_axes diff --git a/lib/disk/data.py b/lib/disk/data.py index 118e5b2..44afb25 100644 --- a/lib/disk/data.py +++ b/lib/disk/data.py @@ -1,31 +1,45 @@ """Structures (pytrees) for working with disk data.""" +from typing import cast + import jax import jax_dataclasses import numpy as onp from jax import numpy as jnp from typing_extensions import Annotated -from .. import array_struct +null_array = cast(jnp.ndarray, None) +# ^Placeholder value to be used as a dataclass field default, to enable structs that +# contain only a partial set of values. +# +# An intuitive solution is to populate fields with a dummy default array like +# `jnp.empty(shape=(0,))`), but this can cause silent broadcasting/tracing issues. +# +# So instead we use `None` as the default value. Which is nice because it leads to loud +# runtime errors when uninitialized values are accidentally used. +# +# Note that the correct move would be to hint fields as `Optional[jnp.ndarray]`, but +# this would result in code that's littered with `assert __ is not None` statements +# and/or casts. Which is annoying. 
So instead we just pretend `None` is an array.


 @jax_dataclasses.pytree_dataclass
-class _DiskStruct(array_struct.ShapeAnnotatedStruct):
+class _DiskStruct(jax_dataclasses.EnforcedAnnotationsMixin):
     """Values in our dataset."""

-    image: Annotated[jnp.ndarray, (120, 120, 3)] = array_struct.null_array
-    visible_pixels_count: Annotated[jnp.ndarray, ()] = array_struct.null_array
-    position: Annotated[jnp.ndarray, (2,)] = array_struct.null_array
-    velocity: Annotated[jnp.ndarray, (2,)] = array_struct.null_array
+    image: Annotated[jnp.ndarray, (120, 120, 3), jnp.floating] = null_array
+    visible_pixels_count: Annotated[jnp.ndarray, (), jnp.floating] = null_array
+    position: Annotated[jnp.ndarray, (2,), jnp.floating] = null_array
+    velocity: Annotated[jnp.ndarray, (2,), jnp.floating] = null_array


-_DATASET_MEANS = _DiskStruct(
+_DATASET_MEANS = dict(
     image=onp.array([24.30598765, 29.76503314, 29.86749727], dtype=onp.float32),  # type: ignore
     position=onp.array([-0.08499543, 0.07917813], dtype=onp.float32),  # type: ignore
     velocity=onp.array([0.02876372, 0.06096543], dtype=onp.float32),  # type: ignore
     visible_pixels_count=104.87143,  # type: ignore
 )

-_DATASET_STD_DEVS = _DiskStruct(
+_DATASET_STD_DEVS = dict(
     image=onp.array([74.88154621, 81.87872827, 82.00088091], dtype=onp.float32),  # type: ignore
     position=onp.array([30.53421, 30.84835], dtype=onp.float32),  # type: ignore
     velocity=onp.array([6.636913, 6.647381], dtype=onp.float32),  # type: ignore
@@ -39,8 +53,8 @@ def normalize(self, scale_only: bool = False) -> "DiskStructNormalized":
         """Normalize contents."""

         def _norm(value, mean, std):
-            if value is array_struct.null_array:
-                return array_struct.null_array
+            if value is null_array:
+                return null_array

             if scale_only:
                 return value / std
@@ -51,8 +65,8 @@ def _norm(value, mean, std):
             **jax.tree_map(
                 _norm,
                 vars(self),
-                vars(_DATASET_MEANS),
-                vars(_DATASET_STD_DEVS),
+                _DATASET_MEANS,
+                _DATASET_STD_DEVS,
             )
         )

@@ -63,8 +77,8 @@ def unnormalize(self, scale_only: bool = False) -> "DiskStructRaw":
         """Unnormalize contents."""

         def _unnorm(value, mean, std):
-            if value is array_struct.null_array:
-                return array_struct.null_array
+            if value is null_array:
+                return null_array

             if scale_only:
                 return value * std
@@ -75,7 +89,7 @@ def _unnorm(value, mean, std):
             **jax.tree_map(
                 _unnorm,
                 vars(self),
-                vars(_DATASET_MEANS),
-                vars(_DATASET_STD_DEVS),
+                _DATASET_MEANS,
+                _DATASET_STD_DEVS,
             )
         )
diff --git a/lib/disk/data_loading.py b/lib/disk/data_loading.py
index 80cdf08..9e20ad9 100644
--- a/lib/disk/data_loading.py
+++ b/lib/disk/data_loading.py
@@ -4,7 +4,11 @@

 import fannypack
 import jax
-import torch
+
+# For future projects, we probably want to use fifteen.data.DataLoader instead of the
+# torch DataLoader, but keeping the torch one because that's what was used for the paper
+# results.
+import torch.utils.data

 from .. import utils
 from . import data, experiment_config
diff --git a/lib/disk/experiment_config.py b/lib/disk/experiment_config.py
index 283a161..d246692 100644
--- a/lib/disk/experiment_config.py
+++ b/lib/disk/experiment_config.py
@@ -1,3 +1,9 @@
+"""Experiment configurations for disk task.
+
+Note: we'd structure these very differently if we were to rewrite this code,
+particularly to replace inheritance with nested dataclasses for common fields.
(the +latter is now supported in `dcargs`)""" + import dataclasses import enum from typing import Literal diff --git a/lib/disk/networks_lstm.py b/lib/disk/networks_lstm.py index 4bcbce7..8f9ea0b 100644 --- a/lib/disk/networks_lstm.py +++ b/lib/disk/networks_lstm.py @@ -18,7 +18,7 @@ class DiskLstm(nn.Module): @nn.compact def __call__(self, inputs: data.DiskStructNormalized) -> jnp.ndarray: - N, T = inputs.check_shapes_and_get_batch_axes() + N, T = inputs.get_batch_axes() images = inputs.image assert images.shape == (N, T, 120, 120, 3) diff --git a/lib/disk/training_ekf.py b/lib/disk/training_ekf.py index d265604..35578f8 100644 --- a/lib/disk/training_ekf.py +++ b/lib/disk/training_ekf.py @@ -1,11 +1,12 @@ from typing import Any, Optional, Tuple +import fifteen import jax import jax_dataclasses import optax from jax import numpy as jnp -from .. import experiment_files, manifold_ekf, utils +from .. import manifold_ekf, utils from . import data, experiment_config, fg_system, networks Pytree = Any @@ -42,7 +43,7 @@ def initialize( ) -> "TrainState": # Load position CNN cnn_model, cnn_params = networks.make_position_cnn(seed=config.random_seed) - cnn_params = experiment_files.ExperimentFiles( + cnn_params = fifteen.experiments.Experiment( identifier=config.pretrained_virtual_sensor_identifier.format( dataset_fold=config.dataset_fold ) @@ -88,17 +89,17 @@ def initialize( @jax.jit def training_step( self, batch: data.DiskStructNormalized - ) -> Tuple["TrainState", experiment_files.TensorboardLogData]: + ) -> Tuple["TrainState", fifteen.experiments.TensorboardLogData]: # Shape checks - (batch_size, sequence_length) = batch.check_shapes_and_get_batch_axes() + (batch_size, sequence_length) = batch.get_batch_axes() assert sequence_length == self.config.train_sequence_length def compute_loss_single( trajectory: data.DiskStructNormalized, learnable_params: Pytree, ) -> jnp.ndarray: - (timesteps,) = trajectory.check_shapes_and_get_batch_axes() + (timesteps,) = trajectory.get_batch_axes() unnormed_position = ( data.DiskStructNormalized(position=trajectory.position) @@ -164,7 +165,7 @@ def compute_loss( ) # Log data - log_data = experiment_files.TensorboardLogData( + log_data = fifteen.experiments.TensorboardLogData( scalars={ "train/training_loss": loss, "train/gradient_norm": optax.global_norm(grads), @@ -184,7 +185,7 @@ def run_ekf( trajectory: data.DiskStructNormalized, learnable_params: Optional[Pytree] = None, ) -> manifold_ekf.MultivariateGaussian[fg_system.State]: - (timesteps,) = trajectory.check_shapes_and_get_batch_axes() + (timesteps,) = trajectory.get_batch_axes() # Some type aliases Belief = manifold_ekf.MultivariateGaussian[fg_system.State] diff --git a/lib/disk/training_fg.py b/lib/disk/training_fg.py index 1c3b5d9..126f76d 100644 --- a/lib/disk/training_fg.py +++ b/lib/disk/training_fg.py @@ -1,12 +1,13 @@ from typing import Any, Optional, Tuple +import fifteen import jax import jax_dataclasses import jaxfg import optax from jax import numpy as jnp -from .. import experiment_files, utils +from .. import utils from . 
import data, experiment_config, fg_system, fg_utils, networks Pytree = Any @@ -41,7 +42,7 @@ def initialize( ) -> "TrainState": # Load position CNN cnn_model, cnn_params = networks.make_position_cnn(seed=config.random_seed) - cnn_params = experiment_files.ExperimentFiles( + cnn_params = fifteen.experiments.Experiment( identifier=config.pretrained_virtual_sensor_identifier.format( dataset_fold=config.dataset_fold ) @@ -89,7 +90,7 @@ def update_factor_graph( learnable_params: Optional[Pytree] = None, ) -> jaxfg.core.StackedFactorGraph: # Shape checks - (sequence_length,) = trajectory.check_shapes_and_get_batch_axes() + (sequence_length,) = trajectory.get_batch_axes() # Optional parameters default to self.* if graph_template is None: @@ -136,16 +137,16 @@ def update_factor_graph( @jax.jit def training_step( self, batch: data.DiskStructNormalized - ) -> Tuple["TrainState", experiment_files.TensorboardLogData]: + ) -> Tuple["TrainState", fifteen.experiments.TensorboardLogData]: # Shape checks - (batch_size, sequence_length) = batch.check_shapes_and_get_batch_axes() + (batch_size, sequence_length) = batch.get_batch_axes() assert sequence_length == self.config.train_sequence_length def compute_loss_single( trajectory: data.DiskStructNormalized, learnable_params: Pytree, - ) -> Tuple[jnp.ndarray, experiment_files.TensorboardLogData]: + ) -> Tuple[jnp.ndarray, fifteen.experiments.TensorboardLogData]: graph = self.update_factor_graph( trajectory=trajectory, graph_template=self.graph_template, @@ -190,7 +191,7 @@ def compute_loss_single( ** 2 ) - log_data = experiment_files.TensorboardLogData( + log_data = fifteen.experiments.TensorboardLogData( histograms={ "regressed_uncertainties": graph.factor_stacks[ 0 @@ -201,14 +202,14 @@ def compute_loss_single( def compute_loss( learnable_params: Pytree, - ) -> Tuple[jnp.ndarray, experiment_files.TensorboardLogData]: + ) -> Tuple[jnp.ndarray, fifteen.experiments.TensorboardLogData]: losses, log_data = jax.vmap(compute_loss_single, in_axes=(0, None))( batch, learnable_params ) return jnp.mean(losses), log_data # Compute loss + backprop => apply gradient transforms => update parameters - log_data: experiment_files.TensorboardLogData + log_data: fifteen.experiments.TensorboardLogData (loss, log_data), grads = jax.value_and_grad(compute_loss, has_aux=True)( self.learnable_params ) @@ -221,8 +222,8 @@ def compute_loss( ) # Log data - log_data = log_data.extend( - scalars={ + log_data = log_data.merge_scalars( + { "train/training_loss": loss, "train/gradient_norm": optax.global_norm(grads), }, diff --git a/lib/disk/training_lstm.py b/lib/disk/training_lstm.py index 717e805..4da5e4a 100644 --- a/lib/disk/training_lstm.py +++ b/lib/disk/training_lstm.py @@ -1,11 +1,12 @@ from typing import Any, Tuple +import fifteen import jax import jax_dataclasses import optax from jax import numpy as jnp -from .. import experiment_files, utils +from .. import utils from . 
import data, experiment_config, networks_lstm Pytree = Any @@ -63,10 +64,10 @@ def initialize( @jax.jit def training_step( self, batch: data.DiskStructNormalized - ) -> Tuple["TrainState", experiment_files.TensorboardLogData]: + ) -> Tuple["TrainState", fifteen.experiments.TensorboardLogData]: # Shape checks - (batch_size, sequence_length) = batch.check_shapes_and_get_batch_axes() + (batch_size, sequence_length) = batch.get_batch_axes() assert sequence_length == self.config.train_sequence_length def compute_loss( @@ -87,7 +88,7 @@ def compute_loss( ) # Log data - log_data = experiment_files.TensorboardLogData( + log_data = fifteen.experiments.TensorboardLogData( scalars={ "train/training_loss": loss, "train/gradient_norm": optax.global_norm(grads), diff --git a/lib/disk/training_virtual_sensor.py b/lib/disk/training_virtual_sensor.py index 2552472..4b90c6b 100644 --- a/lib/disk/training_virtual_sensor.py +++ b/lib/disk/training_virtual_sensor.py @@ -1,11 +1,11 @@ from typing import Any, Optional, Tuple +import fifteen import jax import jax_dataclasses import optax from jax import numpy as jnp -from .. import experiment_files from . import data, experiment_config, networks Pytree = Any @@ -70,11 +70,11 @@ def compute_loss( @jax.jit def training_step( self, batch: data.DiskStructNormalized - ) -> Tuple["TrainState", experiment_files.TensorboardLogData]: + ) -> Tuple["TrainState", fifteen.experiments.TensorboardLogData]: """Single training step.""" # Quick shape check - (batch_size,) = batch.check_shapes_and_get_batch_axes() + (batch_size,) = batch.get_batch_axes() # Compute loss + backprop => apply gradient transforms => update parameters (loss, cnn_outputs), grads = jax.value_and_grad( @@ -86,7 +86,7 @@ def training_step( learnable_params_new = optax.apply_updates(self.learnable_params, updates) # Log data - log_data = experiment_files.TensorboardLogData( + log_data = fifteen.experiments.TensorboardLogData( scalars={ "train/training_loss": loss, "train/gradient_norm": optax.global_norm(grads), diff --git a/lib/disk/validation_fg.py b/lib/disk/validation_fg.py index dc98bdb..3944c41 100644 --- a/lib/disk/validation_fg.py +++ b/lib/disk/validation_fg.py @@ -38,7 +38,7 @@ def make_compute_metrics( ) -> Callable[[training_fg.TrainState], validation_tracker.ValidationMetrics]: eval_trajectories = data_loading.load_trajectories(train=False, fold=dataset_fold) - (trajectory_length,) = eval_trajectories[0].check_shapes_and_get_batch_axes() + (trajectory_length,) = eval_trajectories[0].get_batch_axes() graph_template = fg_utils.make_factor_graph(trajectory_length=trajectory_length) def compute_metrics( diff --git a/lib/experiment_files.py b/lib/experiment_files.py deleted file mode 100644 index 5059f43..0000000 --- a/lib/experiment_files.py +++ /dev/null @@ -1,222 +0,0 @@ -"""Simple experiment manager for managing metadata, logs, and checkpoints.""" - -import dataclasses -import pathlib -import shutil -from typing import Any, Dict, Optional, Type, TypeVar, Union, overload - -import flax.metrics.tensorboard -import flax.training.checkpoints -import jax_dataclasses -import yaml -from jax import numpy as jnp - -T = TypeVar("T") -PytreeType = TypeVar("PytreeType") -Pytree = Any - - -@jax_dataclasses.pytree_dataclass -class TensorboardLogData: - scalars: Dict[str, jnp.ndarray] = jax_dataclasses.field(default_factory=dict) - histograms: Dict[str, jnp.ndarray] = jax_dataclasses.field(default_factory=dict) - - @staticmethod - def merge(a: "TensorboardLogData", b: "TensorboardLogData") -> 
"TensorboardLogData": - return TensorboardLogData( - scalars=dict(**a.scalars, **b.scalars), - histograms=dict(**a.histograms, **b.histograms), - ) - - def extend( - self, - scalars: Dict[str, jnp.ndarray] = {}, - histograms: Dict[str, jnp.ndarray] = {}, - ): - return TensorboardLogData.merge( - self, - TensorboardLogData(scalars=scalars, histograms=histograms), - ) - - -@dataclasses.dataclass(frozen=True) -class ExperimentFiles: - """Helper class for locating checkpoints, logs, and any experiment metadata.""" - - identifier: str - verbose: bool = True - - # Generated in __post_init__ - data_dir: pathlib.Path = dataclasses.field(init=False) - - def __post_init__(self) -> None: - # Assign checkpoint + log directories - root = pathlib.Path("./experiments") - super().__setattr__("data_dir", root / self.identifier) - - def assert_new(self) -> "ExperimentFiles": - """Makes sure that there are no existing checkpoints, logs, or metadata. Returns - self.""" - assert not self.data_dir.exists() or tuple(self.data_dir.iterdir()) == () - return self - - def assert_exists(self) -> "ExperimentFiles": - """Makes sure that there are existing checkpoints, logs, or metadata. Returns - self.""" - assert self.data_dir.exists() and tuple(self.data_dir.iterdir()) != () - return self - - def clear(self) -> "ExperimentFiles": - """Delete all checkpoints, logs, and metadata associated with an experiment. - Returns self.""" - - def error_cb(func: Any, path: str, exc_info: Any) -> None: - """Error callback for shutil.rmtree.""" - self._print(f"Error deleting {path}") - - def delete_path(path: pathlib.Path, n: int = 5) -> None: - """Deletes a path, as well as up to `n` empty parent directories.""" - if not path.exists(): - return - - shutil.rmtree(path, onerror=error_cb) - self._print("Deleting", path) - - if n > 0 and len(list(path.parent.iterdir())) == 0: - delete_path(path.parent, n - 1) - - delete_path(self.data_dir) - - return self - - def move(self, new_identifier: str) -> "ExperimentFiles": - """Move all files corresponding to an experiment to a new identifier. Returns - updated ExperimentFiles object.""" - new_experiment = ExperimentFiles( - identifier=new_identifier, verbose=self.verbose - ) - - def move(src: pathlib.Path, dst=pathlib.Path) -> None: - if not src.exists(): - return - self._print("Moving {src} to {dst}") - shutil.move(src=str(src), dst=str(dst)) - - move(src=self.data_dir, dst=new_experiment.data_dir) - - return new_experiment - - def write_metadata(self, name: str, object: Any, overwrite: bool = True) -> None: - """Serialize an object as a yaml file, then save it to the experiment's metadata - directory.""" - self._ensure_directory_exists(self.data_dir) - assert not name.endswith(".yaml") - - path = self.data_dir / (name + ".yaml") - assert overwrite or not path.exists(), "Metadata file already exists!" - - self._print("Writing metadata to", path) - with open(path, "w") as file: - file.write(yaml.dump(object)) - - @overload - def read_metadata(self, name: str, expected_type: Type[T]) -> T: - ... - - @overload - def read_metadata(self, name: str, expected_type: None = None) -> Any: - ... - - def read_metadata( - self, name: str, expected_type: Optional[Type[T]] = None - ) -> Union[T, Any]: - """Load an object from the experiment's metadata directory.""" - path = self.data_dir / (name + ".yaml") - - self._print("Reading metadata from", path) - with open(path, "r") as file: - output = yaml.load( - file.read(), - Loader=yaml.Loader, # Unsafe loading! 
- ) - - assert expected_type is None or isinstance(output, expected_type) - return output - - def save_checkpoint( - self, - target: Pytree, - step: int, - prefix: str = "checkpoint_", - keep: int = 1, - ) -> str: - """Thin wrapper around flax's `save_checkpoint()` function. - Returns a file name, as a string.""" - self._ensure_directory_exists(self.data_dir) - filename = flax.training.checkpoints.save_checkpoint( - ckpt_dir=self.data_dir, - target=target, - step=step, - prefix=prefix, - keep=keep, - ) - self._print("Saved checkpoint to", filename) - return filename - - def restore_checkpoint( - self, - target: PytreeType, - step: Optional[int] = None, - prefix: str = "checkpoint_", - ) -> PytreeType: - """Thin wrapper around flax's `restore_checkpoint()` function.""" - state_dict = flax.training.checkpoints.restore_checkpoint( - ckpt_dir=self.data_dir, - target=None, # Allows us to assert that a checkpoint was actually found - step=step, - prefix=prefix, - ) - assert state_dict is not None, "No checkpoint found!" - return flax.serialization.from_state_dict(target, state_dict) - - @property - def summary_writer(self) -> flax.metrics.tensorboard.SummaryWriter: - """Helper for Tensorboard logging.""" - if not hasattr(self, "__summary_writer__"): - object.__setattr__( - self, - "__summary_writer__", - flax.metrics.tensorboard.SummaryWriter(log_dir=self.data_dir), - ) - return object.__getattribute__(self, "__summary_writer__") - - def log( - self, - log_data: TensorboardLogData, - step: int, - log_scalars_every_n: Optional[int] = None, - log_histograms_every_n: Optional[int] = None, - ): - """Logging helper for Tensorboard. Not jit-friendly. - - TODO: we should phase out either this interface or the host callback one. This - one could use some polish and requires more boilerplate, but is nice because - it's more explicit.""" - - if log_scalars_every_n is not None and step % log_scalars_every_n == 0: - for k, v in log_data.scalars.items(): - self.summary_writer.scalar(k, v, step=step) - if log_histograms_every_n is not None and step % log_histograms_every_n == 0: - for k, v in log_data.histograms.items(): - self.summary_writer.histogram(k, v, step=step) - - def _ensure_directory_exists(self, path: pathlib.Path) -> None: - """Helper for... ensuring that directories exist.""" - if not path.exists(): - path.mkdir(parents=True) - self._print(f"Made directory at {path}") - - def _print(self, *args, **kwargs) -> None: - """Prefixed printing helper. No-op if `verbose` is set to `False`.""" - if self.verbose: - print(f"[{type(self).__name__}-{self.identifier}]", *args, **kwargs) diff --git a/lib/kitti/data.py b/lib/kitti/data.py index b8055e8..5ef12a3 100644 --- a/lib/kitti/data.py +++ b/lib/kitti/data.py @@ -1,7 +1,7 @@ """Structures (pytrees) for working with KITTI data.""" import dataclasses -from typing import List, TypeVar +from typing import List, TypeVar, cast import jax import jax_dataclasses @@ -11,24 +11,36 @@ from jax import numpy as jnp from typing_extensions import Annotated -from .. import array_struct - _KittiStructType = TypeVar("_KittiStructType", bound="_KittiStruct") +null_array = cast(jnp.ndarray, None) +# ^Placeholder value to be used as a dataclass field default, to enable structs that +# contain only a partial set of values. +# +# An intuitive solution is to populate fields with a dummy default array like +# `jnp.empty(shape=(0,))`), but this can cause silent broadcasting/tracing issues. +# +# So instead we use `None` as the default value. 
Which is nice because it leads to loud
+# runtime errors when uninitialized values are accidentally used.
+#
+# Note that the correct move would be to hint fields as `Optional[jnp.ndarray]`, but
+# this would result in code that's littered with `assert __ is not None` statements
+# and/or casts. Which is annoying. So instead we just pretend `None` is an array.


 @jax_dataclasses.pytree_dataclass
-class _KittiStruct(array_struct.ShapeAnnotatedStruct):
+class _KittiStruct(jax_dataclasses.EnforcedAnnotationsMixin):
     """Base class for storing KITTI data, which can either be normalized or not."""

     # Annotated[..., ...] attaches an expected shape to each field. (which may end up
     # being prefixed by a shared set of batch axes)
-    image: Annotated[jnp.ndarray, (50, 150, 3)] = array_struct.null_array
-    image_diff: Annotated[jnp.ndarray, (50, 150, 3)] = array_struct.null_array
-    x: Annotated[jnp.ndarray, ()] = array_struct.null_array
-    y: Annotated[jnp.ndarray, ()] = array_struct.null_array
-    theta: Annotated[jnp.ndarray, ()] = array_struct.null_array
-    linear_vel: Annotated[jnp.ndarray, ()] = array_struct.null_array
-    angular_vel: Annotated[jnp.ndarray, ()] = array_struct.null_array
+    image: Annotated[jnp.ndarray, (50, 150, 3), jnp.floating] = null_array
+    image_diff: Annotated[jnp.ndarray, (50, 150, 3), jnp.floating] = null_array
+    x: Annotated[jnp.ndarray, (), jnp.floating] = null_array
+    y: Annotated[jnp.ndarray, (), jnp.floating] = null_array
+    theta: Annotated[jnp.ndarray, (), jnp.floating] = null_array
+    linear_vel: Annotated[jnp.ndarray, (), jnp.floating] = null_array
+    angular_vel: Annotated[jnp.ndarray, (), jnp.floating] = null_array

     def get_stacked_velocity(self) -> jnp.ndarray:
         """Return 2-channel velocity."""
@@ -40,7 +52,7 @@ def get_stacked_velocity(self) -> jnp.ndarray:

 # Constants for data normalization. We again perform a lot of type abuse...
-_DATASET_MEANS = _KittiStruct( +_DATASET_MEANS = dict( image=onp.array([88.91195932, 94.08863257, 92.80115751]), # type: ignore image_diff=onp.array([-0.00086295, -0.00065804, -0.00133435]), # type: ignore x=195.02545, # type: ignore @@ -50,7 +62,7 @@ def get_stacked_velocity(self) -> jnp.ndarray: angular_vel=-0.000439872, # type: ignore ) -_DATASET_STD_DEVS = _KittiStruct( +_DATASET_STD_DEVS = dict( image=onp.array([74.12011514, 76.13433045, 77.88847008]), # type: ignore image_diff=onp.array([38.63185147, 39.0655375, 38.7856255]), # type: ignore x=294.42093, # type: ignore @@ -73,8 +85,8 @@ def unnormalize(self, scale_only: bool = False) -> "KittiStructRaw": """Unnormalize contents.""" def _unnorm(value, mean, std): - if value is array_struct.null_array: - return array_struct.null_array + if value is null_array: + return null_array if scale_only: return value * std @@ -85,8 +97,8 @@ def _unnorm(value, mean, std): **jax.tree_map( _unnorm, vars(self), - vars(_DATASET_MEANS), - vars(_DATASET_STD_DEVS), + _DATASET_MEANS, + _DATASET_STD_DEVS, ) ) @@ -133,8 +145,8 @@ def normalize(self, scale_only: bool = False) -> "KittiStructNormalized": """Normalize contents.""" def _norm(value, mean, std): - if value is array_struct.null_array: - return array_struct.null_array + if value is null_array: + return null_array if scale_only: return value / std @@ -145,8 +157,8 @@ def _norm(value, mean, std): **jax.tree_map( _norm, vars(self), - vars(_DATASET_MEANS), - vars(_DATASET_STD_DEVS), + _DATASET_MEANS, + _DATASET_STD_DEVS, ) ) diff --git a/lib/kitti/data_loading.py b/lib/kitti/data_loading.py index bd65c4a..780f68a 100644 --- a/lib/kitti/data_loading.py +++ b/lib/kitti/data_loading.py @@ -8,7 +8,11 @@ import fannypack import jax import numpy as onp -import torch + +# For future projects, we probably want to use fifteen.data.DataLoader instead of the +# torch DataLoader, but keeping the torch one because that's what was used for the paper +# results. +import torch.utils.data from .. import utils from . import data, experiment_config diff --git a/lib/kitti/experiment_config.py b/lib/kitti/experiment_config.py index 4cb7f7a..42345d5 100644 --- a/lib/kitti/experiment_config.py +++ b/lib/kitti/experiment_config.py @@ -1,11 +1,13 @@ """Experiment configurations for KITTI task. -""" + +Note: we'd likely structure these very differently if we were to rewrite this code, +particularly to replace inheritance with nested dataclasses for common fields. (the +latter is now supported in `dcargs`)""" + import dataclasses import enum from typing import Literal, Union -import datargs - from .. import utils #################### @@ -92,17 +94,15 @@ class InitializationStrategyEnum(utils.StringEnum): NAIVE_BASELINE = enum.auto() -@datargs.argsclass(name="joint-nll-loss") @dataclasses.dataclass(frozen=True) class JointNllLossConfig: """Joint NLL loss configuration. - Empty dataclass to match syntax used for defining subparsers in datargs.""" + Empty dataclass to match syntax used for defining subparsers in dcargs.""" pass -@datargs.argsclass(name="surrogate-loss") @dataclasses.dataclass(frozen=True) class SurrogateLossConfig: """End-to-end surrogate loss configuration.""" diff --git a/lib/kitti/networks.py b/lib/kitti/networks.py index fa12608..b084160 100644 --- a/lib/kitti/networks.py +++ b/lib/kitti/networks.py @@ -2,12 +2,12 @@ from typing import Any, NamedTuple, Protocol, Tuple +import fifteen import jax import numpy as onp from flax import linen as nn from jax import numpy as jnp -from .. import experiment_files from . 
import data, experiment_config Pytree = Any @@ -151,7 +151,7 @@ def load_pretrained_observation_cnn( # Note that seed does not matter, because parameters will be read from checkpoint model, params = make_observation_cnn(random_seed=0) - experiment = experiment_files.ExperimentFiles(identifier=experiment_identifier) + experiment = fifteen.experiments.Experiment(identifier=experiment_identifier) params = experiment.restore_checkpoint(params, prefix="best_val_params_") return model, params diff --git a/lib/kitti/networks_lstm.py b/lib/kitti/networks_lstm.py index cfd307c..693ff71 100644 --- a/lib/kitti/networks_lstm.py +++ b/lib/kitti/networks_lstm.py @@ -20,7 +20,7 @@ class KittiLstm(nn.Module): @nn.compact def __call__(self, inputs: data.KittiStructNormalized, train: bool) -> jaxlie.SE2: - N, T = inputs.check_shapes_and_get_batch_axes() + N, T = inputs.get_batch_axes() stacked_images = inputs.get_stacked_image() assert stacked_images.shape == (N, T, 50, 150, 6) diff --git a/lib/kitti/training_ekf.py b/lib/kitti/training_ekf.py index 85eb971..e598d1a 100644 --- a/lib/kitti/training_ekf.py +++ b/lib/kitti/training_ekf.py @@ -1,12 +1,13 @@ from typing import Any, Optional, Tuple +import fifteen import jax import jax_dataclasses import jaxlie import optax from jax import numpy as jnp -from .. import experiment_files, manifold_ekf, utils +from .. import manifold_ekf, utils from . import data, experiment_config, fg_system, networks @@ -121,7 +122,7 @@ def initialize( @jax.jit def training_step( self, batched_trajectory: data.KittiStructNormalized - ) -> Tuple["TrainState", experiment_files.TensorboardLogData]: + ) -> Tuple["TrainState", fifteen.experiments.TensorboardLogData]: def compute_loss_single( learnable_params: Pytree, trajectory: data.KittiStructNormalized, @@ -162,7 +163,7 @@ def compute_loss(learnable_params: Pytree, prng_key: PRNGKey) -> jnp.ndarray: ) # Data to log - log_data = experiment_files.TensorboardLogData( + log_data = fifteen.experiments.TensorboardLogData( scalars={ "train/training_loss": loss, "train/gradient_norm": optax.global_norm(grads), @@ -184,7 +185,7 @@ def run_ekf( prng_key: PRNGKey, learnable_params: Optional[Pytree] = None, ) -> fg_system.State: - (_timesteps,) = trajectory.check_shapes_and_get_batch_axes() + (_timesteps,) = trajectory.get_batch_axes() # Some type aliases Belief = manifold_ekf.MultivariateGaussian[fg_system.State] @@ -217,7 +218,7 @@ def run_ekf( cov=jnp.eye(5) * 1e-7, # This can probably just be zeros ) - (timesteps,) = trajectory.check_shapes_and_get_batch_axes() + (timesteps,) = trajectory.get_batch_axes() def ekf_step( # carry diff --git a/lib/kitti/training_fg.py b/lib/kitti/training_fg.py index 4f3ff62..90110f9 100644 --- a/lib/kitti/training_fg.py +++ b/lib/kitti/training_fg.py @@ -1,13 +1,14 @@ """Training helpers for KITTI task.""" from typing import Any, Optional, Tuple +import fifteen import jax import jax_dataclasses import jaxfg import optax from jax import numpy as jnp -from .. import experiment_files, utils +from .. import utils from . 
import data, experiment_config, fg_losses, fg_utils, networks Pytree = Any @@ -156,11 +157,11 @@ def update_factor_graph( @jax.jit def training_step( self, batched_trajectory: data.KittiStructNormalized - ) -> Tuple["TrainState", experiment_files.TensorboardLogData]: + ) -> Tuple["TrainState", fifteen.experiments.TensorboardLogData]: """Single training step.""" # Shared leading axes should be (batch, timesteps) - assert len(batched_trajectory.check_shapes_and_get_batch_axes()) == 2 + assert len(batched_trajectory.get_batch_axes()) == 2 def compute_loss_single( learnable_params: LearnableParams, @@ -224,7 +225,7 @@ def compute_loss( # Log data regressed_velocities = per_sample_metadata.regressed_velocities regressed_uncertainties = per_sample_metadata.regressed_uncertainties - log_data = experiment_files.TensorboardLogData( + log_data = fifteen.experiments.TensorboardLogData( scalars={ "train/training_loss": loss, "train/gradient_norm": optax.global_norm(grads), diff --git a/lib/kitti/training_lstm.py b/lib/kitti/training_lstm.py index d7a594e..ea67f2b 100644 --- a/lib/kitti/training_lstm.py +++ b/lib/kitti/training_lstm.py @@ -1,12 +1,13 @@ from typing import Any, Tuple +import fifteen import jax import jax_dataclasses import jaxlie import optax from jax import numpy as jnp -from .. import experiment_files, utils +from .. import utils from . import data, experiment_config, math_utils, networks_lstm PRNGKey = jnp.ndarray # TODO: we should standardize PRNG vs Prng @@ -77,12 +78,12 @@ def initialize( @jax.jit def training_step( self, batch: data.KittiStructNormalized - ) -> Tuple["TrainState", experiment_files.TensorboardLogData]: + ) -> Tuple["TrainState", fifteen.experiments.TensorboardLogData]: def compute_loss( learnable_params: LearnableParams, prng_key: PRNGKey - ) -> Tuple[jnp.ndarray, experiment_files.TensorboardLogData]: + ) -> Tuple[jnp.ndarray, fifteen.experiments.TensorboardLogData]: """Compute average loss for all trajectories in the batch.""" - (_N, _T) = batch.check_shapes_and_get_batch_axes() + (_N, _T) = batch.get_batch_axes() batch_unnorm = batch.unnormalize() @@ -112,7 +113,7 @@ def compute_loss( training_loss = translation_loss + rotation_loss - return training_loss, experiment_files.TensorboardLogData( + return training_loss, fifteen.experiments.TensorboardLogData( scalars={ "train/translation_loss": translation_loss, "train/rotation_loss": rotation_loss, @@ -144,9 +145,9 @@ def compute_loss( ) # Data to log - log_data = experiment_files.TensorboardLogData.merge( + log_data = fifteen.experiments.TensorboardLogData.merge( compute_loss_log_data, - experiment_files.TensorboardLogData( + fifteen.experiments.TensorboardLogData( scalars={ "train/gradient_norm": optax.global_norm(grads), } diff --git a/lib/kitti/training_virtual_sensor.py b/lib/kitti/training_virtual_sensor.py index 64a54f5..3454a0a 100644 --- a/lib/kitti/training_virtual_sensor.py +++ b/lib/kitti/training_virtual_sensor.py @@ -1,11 +1,11 @@ from typing import Optional, Tuple +import fifteen import jax import jax_dataclasses import optax from jax import numpy as jnp -from .. import experiment_files from . 
import data, experiment_config, networks PRNGKey = jnp.ndarray @@ -113,7 +113,7 @@ def compute_loss( def training_step( self, batch: data.KittiStructNormalized, - ) -> Tuple["TrainState", experiment_files.TensorboardLogData]: + ) -> Tuple["TrainState", fifteen.experiments.TensorboardLogData]: """Single training step.""" # Quick shape check @@ -141,7 +141,7 @@ def training_step( angular_vel=cnn_output[:, 1], ).unnormalize() - log_data = experiment_files.TensorboardLogData( + log_data = fifteen.experiments.TensorboardLogData( scalars={ "train/training_loss": loss, "train/gradient_norm": optax.global_norm(grads), diff --git a/lib/kitti/validation_ekf.py b/lib/kitti/validation_ekf.py index 475ce4f..d1a1d79 100644 --- a/lib/kitti/validation_ekf.py +++ b/lib/kitti/validation_ekf.py @@ -21,7 +21,7 @@ def _compute_metrics( train_state: training_ekf.TrainState, trajectory: data.KittiStructNormalized, ) -> _ValidationMetrics: - (_timesteps,) = trajectory.check_shapes_and_get_batch_axes() + (_timesteps,) = trajectory.get_batch_axes() gt_trajectory_raw = trajectory.unnormalize() posterior_states = train_state.run_ekf( @@ -65,7 +65,7 @@ def compute_metrics( traj = eval_dataset[i] # Leading axes: (batch, # timesteps) - (timesteps,) = traj.check_shapes_and_get_batch_axes() + (timesteps,) = traj.get_batch_axes() batch_metrics = _compute_metrics( train_state, diff --git a/lib/kitti/validation_fg.py b/lib/kitti/validation_fg.py index fae8419..8037e3a 100644 --- a/lib/kitti/validation_fg.py +++ b/lib/kitti/validation_fg.py @@ -23,7 +23,7 @@ def _compute_metrics( trajectory: data.KittiStructNormalized, ) -> _ValidationMetrics: # Leading axes: (# timesteps,) - assert len(trajectory.check_shapes_and_get_batch_axes()) == 1 + assert len(trajectory.get_batch_axes()) == 1 graph, _unused = train_state.update_factor_graph( graph_template=graph_template, @@ -93,7 +93,7 @@ def compute_metrics( traj = eval_dataset[i] # Leading axes: (batch, # timesteps) - (timesteps,) = traj.check_shapes_and_get_batch_axes() + (timesteps,) = traj.get_batch_axes() batch_metrics = _compute_metrics( train_state, diff --git a/lib/kitti/validation_lstm.py b/lib/kitti/validation_lstm.py index fe6e67e..5156392 100644 --- a/lib/kitti/validation_lstm.py +++ b/lib/kitti/validation_lstm.py @@ -27,7 +27,7 @@ def compute_metrics( for batch in eval_dataloader: batch_unnorm = batch.unnormalize() - (N, T) = batch.check_shapes_and_get_batch_axes() + (N, T) = batch.get_batch_axes() regressed_poses: jaxlie.SE2 = jax.jit( train_state.lstm.apply, static_argnames=("train",) diff --git a/lib/lstm_layers.py b/lib/lstm_layers.py index a21480e..2757cd1 100644 --- a/lib/lstm_layers.py +++ b/lib/lstm_layers.py @@ -3,6 +3,7 @@ Borrows from: https://github.com/google/flax/blob/main/examples/sst2/models.py """ +import functools from typing import Tuple import jax @@ -13,7 +14,7 @@ class UniLstm(nn.Module): """A simple unidirectional LSTM.""" - @jax.partial( + @functools.partial( nn.transforms.scan, variable_broadcast="params", in_axes=1, diff --git a/lib/train_state_protocol.py b/lib/train_state_protocol.py index abe692e..93058dd 100644 --- a/lib/train_state_protocol.py +++ b/lib/train_state_protocol.py @@ -2,12 +2,13 @@ from typing import Any, Generic, Protocol, Tuple, TypeVar -from . 
import array_struct, experiment_files +import fifteen +import jax_dataclasses SelfType = TypeVar("SelfType") TrainingDataType = TypeVar( "TrainingDataType", - bound=array_struct.ShapeAnnotatedStruct, + bound=jax_dataclasses.EnforcedAnnotationsMixin, contravariant=True, ) Pytree = Any @@ -21,5 +22,5 @@ class TrainStateProtocol(Protocol, Generic[TrainingDataType]): def training_step( self: SelfType, batch: TrainingDataType - ) -> Tuple[SelfType, experiment_files.TensorboardLogData]: + ) -> Tuple[SelfType, fifteen.experiments.TensorboardLogData]: ... diff --git a/lib/utils.py b/lib/utils.py index 70889cb..48ba2c3 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -1,14 +1,11 @@ """Utilities shared by all tasks.""" -import argparse -import dataclasses import enum import pathlib import random -from typing import Any, Iterable, Optional, Type, TypeVar +from typing import Any, Iterable, TypeVar -import datargs -import fannypack +import fifteen import jax import numpy as onp import optax @@ -37,7 +34,7 @@ def value(self) -> str: def get_git_commit_hash() -> str: """Get current repository commit hash.""" - return fannypack.utils.get_git_commit_hash(str(pathlib.Path(__file__).parent)) + return fifteen.utils.get_git_commit_hash(pathlib.Path(__file__).parent) def warmup_schedule(learning_rate: float, warmup_steps: int) -> optax.Schedule: @@ -55,38 +52,3 @@ def set_random_seed(seed: int) -> None: def collate_fn(batch: Iterable[PytreeType], axis=0) -> PytreeType: """Collate function for torch DataLoaders.""" return jax.tree_multimap(lambda *arrays: onp.stack(arrays, axis=axis), *batch) - - -def parse_args( - cls: Type[DataclassType], *, description: Optional[str] = None -) -> DataclassType: - """Populates a dataclass via CLI args. Basically the same as `datargs.parse()`, but - adds default values to helptext.""" - assert dataclasses.is_dataclass(cls) - - # Modify helptext to add default values. - # - # This is a little bit prettier than using the argparse helptext formatter, which - # will include dataclass.MISSING values. - for field in dataclasses.fields(cls): - if field.default is not dataclasses.MISSING: - # Heuristic for if field has already been mutated. By default metadata will - # resolve to a mappingproxy object. - if isinstance(field.metadata, dict): - continue - - # Add default value to helptext! - if hasattr(field.default, "name"): - # Special case for enums - default_fmt = f"(default: {field.default.name})" - else: - default_fmt = "(default: %(default)s)" - - field.metadata = dict(field.metadata) - field.metadata["help"] = ( - f"{field.metadata['help']} {default_fmt}" - if "help" in field.metadata - else default_fmt - ) - - return datargs.parse(cls, parser=argparse.ArgumentParser(description=description)) diff --git a/lib/validation_tracker.py b/lib/validation_tracker.py index 0dc09b4..7f14fb7 100644 --- a/lib/validation_tracker.py +++ b/lib/validation_tracker.py @@ -4,7 +4,9 @@ import dataclasses from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar -from . import experiment_files, train_state_protocol +import fifteen + +from . 
import train_state_protocol Pytree = Any @@ -22,7 +24,7 @@ class ValidationTracker(Generic[TrainState]): """Helper for tracking+logging validation statistics.""" name: str - experiment: experiment_files.ExperimentFiles + experiment: fifteen.experiments.Experiment compute_metrics: Callable[[TrainState], ValidationMetrics] lowest_metric: Optional[float] = dataclasses.field(default=None) diff --git a/requirements.txt b/requirements.txt index dffee70..f6e42a8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,14 @@ jax==0.2.18 jaxlib==0.1.70 +flax==0.3.2 jaxlie +dcargs git+git://github.com/brentyi/jaxfg@master jax_dataclasses tqdm -datargs>=0.10.0 fannypack -torch termcolor # tensorflow is needed for tensorboard logging in JAX tensorflow types-termcolor +git+https://github.com/brentyi/fifteen.git@672446c3cd21ca35d6936a9fb0daa3cab9ba0a36 diff --git a/train_disk_ekf.py b/train_disk_ekf.py index 4e8b9fd..acca0a6 100644 --- a/train_disk_ekf.py +++ b/train_disk_ekf.py @@ -2,14 +2,15 @@ EKF, but since the dynamics and observation model are linear we end up with a standard Kalman filter.""" -import fannypack +import dcargs +import fifteen from tqdm.auto import tqdm -from lib import disk, experiment_files, utils, validation_tracker +from lib import disk, utils, validation_tracker def main(config: disk.experiment_config.EkfExperimentConfig) -> None: - experiment = experiment_files.ExperimentFiles( + experiment = fifteen.experiments.Experiment( identifier=config.experiment_identifier.format(dataset_fold=config.dataset_fold) ).clear() experiment.write_metadata("experiment_config", config) @@ -54,8 +55,8 @@ def main(config: disk.experiment_config.EkfExperimentConfig) -> None: if __name__ == "__main__": - fannypack.utils.pdb_safety_net() - config = utils.parse_args( + fifteen.utils.pdb_safety_net() + config = dcargs.parse( disk.experiment_config.EkfExperimentConfig, description=__doc__, ) diff --git a/train_disk_fg.py b/train_disk_fg.py index 6396938..ace6ad8 100644 --- a/train_disk_fg.py +++ b/train_disk_fg.py @@ -1,13 +1,14 @@ """Factor graph training script for visual tracking task.""" -import fannypack +import dcargs +import fifteen from tqdm.auto import tqdm -from lib import disk, experiment_files, utils, validation_tracker +from lib import disk, utils, validation_tracker def main(config: disk.experiment_config.FactorGraphExperimentConfig) -> None: - experiment = experiment_files.ExperimentFiles( + experiment = fifteen.experiments.Experiment( identifier=config.experiment_identifier.format(dataset_fold=config.dataset_fold) ).clear() experiment.write_metadata("experiment_config", config) @@ -52,8 +53,8 @@ def main(config: disk.experiment_config.FactorGraphExperimentConfig) -> None: if __name__ == "__main__": - fannypack.utils.pdb_safety_net() - config = utils.parse_args( + fifteen.utils.pdb_safety_net() + config = dcargs.parse( disk.experiment_config.FactorGraphExperimentConfig, description=__doc__, ) diff --git a/train_disk_lstm.py b/train_disk_lstm.py index 0cd1b57..cf0c3f9 100644 --- a/train_disk_lstm.py +++ b/train_disk_lstm.py @@ -1,13 +1,14 @@ """LSTM training script for visual tracking task.""" -import fannypack +import dcargs +import fifteen from tqdm.auto import tqdm -from lib import disk, experiment_files, utils, validation_tracker +from lib import disk, utils, validation_tracker def main(config: disk.experiment_config.LstmExperimentConfig) -> None: - experiment = experiment_files.ExperimentFiles( + experiment = fifteen.experiments.Experiment( 
identifier=config.experiment_identifier.format(dataset_fold=config.dataset_fold) ).clear() experiment.write_metadata("experiment_config", config) @@ -52,8 +53,8 @@ def main(config: disk.experiment_config.LstmExperimentConfig) -> None: if __name__ == "__main__": - fannypack.utils.pdb_safety_net() - config = utils.parse_args( + fifteen.utils.pdb_safety_net() + config = dcargs.parse( disk.experiment_config.LstmExperimentConfig, description=__doc__, ) diff --git a/train_disk_virtual_sensor.py b/train_disk_virtual_sensor.py index a791be5..ba5c629 100644 --- a/train_disk_virtual_sensor.py +++ b/train_disk_virtual_sensor.py @@ -2,16 +2,17 @@ import functools -import fannypack +import dcargs +import fifteen from tqdm.auto import tqdm -from lib import disk, experiment_files, utils, validation_tracker +from lib import disk, utils, validation_tracker def main( config: disk.experiment_config.VirtualSensorPretrainingExperimentConfig, ) -> None: - experiment = experiment_files.ExperimentFiles( + experiment = fifteen.experiments.Experiment( identifier=config.experiment_identifier.format(dataset_fold=config.dataset_fold) ).clear() experiment.write_metadata("experiment_config", config) @@ -65,8 +66,8 @@ def main( if __name__ == "__main__": - fannypack.utils.pdb_safety_net() - config = utils.parse_args( + fifteen.utils.pdb_safety_net() + config = dcargs.parse( disk.experiment_config.VirtualSensorPretrainingExperimentConfig, description=__doc__, ) diff --git a/train_kitti_ekf.py b/train_kitti_ekf.py index a99630c..2179f9f 100644 --- a/train_kitti_ekf.py +++ b/train_kitti_ekf.py @@ -1,12 +1,14 @@ """EKF end-to-end training script for visual odometry task.""" +import dcargs +import fifteen from tqdm.auto import tqdm -from lib import experiment_files, kitti, utils, validation_tracker +from lib import kitti, utils, validation_tracker def main(config: kitti.experiment_config.EkfExperimentConfig) -> None: - experiment = experiment_files.ExperimentFiles( + experiment = fifteen.experiments.Experiment( identifier=config.experiment_identifier.format(dataset_fold=config.dataset_fold) ).clear() experiment.write_metadata("experiment_config", config) @@ -53,5 +55,7 @@ def main(config: kitti.experiment_config.EkfExperimentConfig) -> None: if __name__ == "__main__": - config = utils.parse_args(kitti.experiment_config.EkfExperimentConfig) + config = dcargs.parse( + kitti.experiment_config.EkfExperimentConfig, description=__doc__ + ) main(config) diff --git a/train_kitti_fg.py b/train_kitti_fg.py index 55ec7a0..3feeb79 100644 --- a/train_kitti_fg.py +++ b/train_kitti_fg.py @@ -1,13 +1,14 @@ """Factor graph training script for visual odometry task.""" -import fannypack +import dcargs +import fifteen from tqdm.auto import tqdm -from lib import experiment_files, kitti, utils, validation_tracker +from lib import kitti, utils, validation_tracker def main(config: kitti.experiment_config.FactorGraphExperimentConfig) -> None: - experiment = experiment_files.ExperimentFiles( + experiment = fifteen.experiments.Experiment( identifier=config.experiment_identifier.format(dataset_fold=config.dataset_fold) ).clear() experiment.write_metadata("experiment_config", config) @@ -67,8 +68,8 @@ def main(config: kitti.experiment_config.FactorGraphExperimentConfig) -> None: if __name__ == "__main__": - fannypack.utils.pdb_safety_net() - config = utils.parse_args( + fifteen.utils.pdb_safety_net() + config = dcargs.parse( kitti.experiment_config.FactorGraphExperimentConfig, description=__doc__, ) diff --git a/train_kitti_lstm.py b/train_kitti_lstm.py 
index 15013a8..3655b39 100644 --- a/train_kitti_lstm.py +++ b/train_kitti_lstm.py @@ -1,13 +1,14 @@ """LSTM training script for visual odometry task.""" -import fannypack +import dcargs +import fifteen from tqdm.auto import tqdm -from lib import experiment_files, kitti, utils, validation_tracker +from lib import kitti, utils, validation_tracker def main(config: kitti.experiment_config.LstmExperimentConfig) -> None: - experiment = experiment_files.ExperimentFiles( + experiment = fifteen.experiments.Experiment( identifier=config.experiment_identifier.format(dataset_fold=config.dataset_fold) ).clear() experiment.write_metadata("experiment_config", config) @@ -56,6 +57,8 @@ def main(config: kitti.experiment_config.LstmExperimentConfig) -> None: if __name__ == "__main__": - fannypack.utils.pdb_safety_net() - config = utils.parse_args(kitti.experiment_config.LstmExperimentConfig) + fifteen.utils.pdb_safety_net() + config = dcargs.parse( + kitti.experiment_config.LstmExperimentConfig, description=__doc__ + ) main(config) diff --git a/train_kitti_virtual_sensor.py b/train_kitti_virtual_sensor.py index 4cbfa6e..809f2ae 100644 --- a/train_kitti_virtual_sensor.py +++ b/train_kitti_virtual_sensor.py @@ -1,10 +1,10 @@ """Pre-training script for visual odometry task virtual sensors.""" - -import fannypack +import dcargs +import fifteen from jax import numpy as jnp from tqdm.auto import tqdm -from lib import experiment_files, kitti, utils, validation_tracker +from lib import kitti, utils, validation_tracker PRNGKey = jnp.ndarray @@ -12,7 +12,7 @@ def main( config: kitti.experiment_config.VirtualSensorPretrainingExperimentConfig, ) -> None: - experiment = experiment_files.ExperimentFiles( + experiment = fifteen.experiments.Experiment( identifier=config.experiment_identifier.format(dataset_fold=config.dataset_fold) ).clear() experiment.write_metadata("experiment_config", config) @@ -68,8 +68,8 @@ def main( if __name__ == "__main__": - fannypack.utils.pdb_safety_net() - config = utils.parse_args( + fifteen.utils.pdb_safety_net() + config = dcargs.parse( kitti.experiment_config.VirtualSensorPretrainingExperimentConfig, description=__doc__, )
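
Two illustrative sketches of the patterns this patch migrates to. Both are
hypothetical usage examples, not part of the patch itself: `DemoConfig` and
`Velocity2D` are made-up stand-ins, while every library call is one that appears
in the diff above.

First, the shared structure of the rewritten training scripts:
`fifteen.experiments.Experiment` replaces the deleted
`lib.experiment_files.ExperimentFiles`, and `dcargs.parse` replaces
`utils.parse_args` / `datargs`:

    """Hypothetical training entry point mirroring the train_*.py scripts above."""

    import dataclasses

    import dcargs
    import fifteen


    @dataclasses.dataclass(frozen=True)
    class DemoConfig:
        # Made-up stand-in for e.g. disk.experiment_config.EkfExperimentConfig.
        experiment_identifier: str = "demo/fold_{dataset_fold}"
        dataset_fold: int = 0


    def main(config: DemoConfig) -> None:
        # Same identifier-based layout as before; .clear() wipes any stale
        # checkpoints/logs, and config metadata is serialized alongside them.
        experiment = fifteen.experiments.Experiment(
            identifier=config.experiment_identifier.format(
                dataset_fold=config.dataset_fold
            )
        ).clear()
        experiment.write_metadata("experiment_config", config)


    if __name__ == "__main__":
        fifteen.utils.pdb_safety_net()  # Replaces fannypack.utils.pdb_safety_net().
        config = dcargs.parse(DemoConfig, description=__doc__)
        main(config)

Second, the `jax_dataclasses.EnforcedAnnotationsMixin` pattern that replaces the
deleted `ShapeAnnotatedStruct`: shape and dtype metadata attached via
`Annotated[...]` are checked against field values, and `get_batch_axes()` (the
rename of `check_shapes_and_get_batch_axes()`) returns the shared leading batch
axes. A minimal sketch, assuming the mixin validates on construction:

    import jax_dataclasses
    from jax import numpy as jnp
    from typing_extensions import Annotated


    @jax_dataclasses.pytree_dataclass
    class Velocity2D(jax_dataclasses.EnforcedAnnotationsMixin):
        # Expected shape (2,) and a floating dtype; leading batch axes may be
        # prefixed, as with the (batch, time) axes in the training code above.
        xy: Annotated[jnp.ndarray, (2,), jnp.floating]


    batch = Velocity2D(xy=jnp.zeros((8, 32, 2)))  # Hypothetical (8, 32) batch.
    assert batch.get_batch_axes() == (8, 32)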