-
Notifications
You must be signed in to change notification settings - Fork 10
/
train_kitti_fg.py
79 lines (65 loc) · 2.79 KB
/
train_kitti_fg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""Factor graph training script for visual odometry task."""
import pathlib
import fifteen
import tyro
from tqdm.auto import tqdm
from lib import kitti, utils, validation_tracker
def main(config: kitti.experiment_config.FactorGraphExperimentConfig) -> None:
    """Run factor graph training for the experiment described by `config`.

    Creates an experiment rooted under `./experiments/`, stores the config and
    the current git commit hash as metadata, then interleaves training steps
    with periodic validation. Training ends after `config.num_epochs` epochs or
    as soon as the early-stopping check at the end of an epoch fires.

    Args:
        config: Experiment configuration. `experiment_identifier`,
            `dataset_fold`, `random_seed` and `num_epochs` are read here; the
            whole object is also handed to the data loading and train state
            helpers.
    """
    # How often (in training steps) we validate and log.
    validate_every_n_steps = 50
    scalar_log_period = 10
    histogram_log_period = 100

    # Experiment setup: the identifier is a template filled in with the fold.
    experiment_dir = pathlib.Path(
        "./experiments/"
    ) / config.experiment_identifier.format(dataset_fold=config.dataset_fold)
    experiment = fifteen.experiments.Experiment(data_dir=experiment_dir).clear()
    experiment.write_metadata("experiment_config", config)
    experiment.write_metadata("git_commit_hash", utils.get_git_commit_hash())

    # Seed everything except JAX.
    utils.set_random_seed(config.random_seed)

    # Data. Training draws from a chunk of the train set that was held out during
    # pretraining -- this will result in uncertainties that generalize better to
    # unseen data.
    train_dataloader = kitti.data_loading.make_subsequence_dataloader(
        config,
        split=kitti.data_loading.DatasetSplit.TRAIN_VIRTUAL_SENSOR_HOLDOUT,
    )
    val_dataset = kitti.data_loading.make_subsequence_eval_dataset(config)

    # Validation helper, which also handles metric-aware checkpointing.
    validation = validation_tracker.ValidationTracker[kitti.training_fg.TrainState](
        name="val",
        experiment=experiment,
        compute_metrics=kitti.validation_fg.make_compute_metrics(val_dataset),
    )

    # Optimization loop.
    train_state = kitti.training_fg.TrainState.initialize(config, train=True)
    for epoch in tqdm(range(config.num_epochs)):
        batch: kitti.data.KittiStructNormalized
        for batch in train_dataloader:
            # Validate (and checkpoint if this is the best result so far) before
            # taking the step, so step 0 gets evaluated as well.
            if train_state.steps % validate_every_n_steps == 0:
                validation = validation.validate_log_and_checkpoint_if_best(
                    train_state
                )

            train_state, log_data = train_state.training_step(batch)

            # Tensorboard logging.
            experiment.log(
                log_data,
                step=train_state.steps,
                log_scalars_every_n=scalar_log_period,
                log_histograms_every_n=histogram_log_period,
            )

        # Simple early stopping, evaluated once per epoch. This exists because of a
        # time crunch: validation metrics often plateau or show overfitting very
        # early in the train cycle.
        # NOTE(review): placed at epoch granularity because it reads `epoch` and
        # `break` has to leave the epoch loop to stop training -- confirm.
        if epoch <= config.num_epochs // 3:
            # Always train through the first third of the epoch budget.
            continue
        if validation.best_step is None:
            # No best checkpoint recorded yet; nothing to compare against.
            continue
        if validation.best_step <= int(train_state.steps) // 2:
            # The best result lies in the first half of all steps taken so far.
            print(
                f"Early stopping: {validation.best_step=}, {epoch=}, {int(train_state.steps)=}"
            )
            break
if __name__ == "__main__":
    # Script entry point.
    # NOTE(review): `pdb_safety_net` presumably drops into a debugger on an
    # uncaught exception -- confirm against the `fifteen` docs.
    fifteen.utils.pdb_safety_net()

    # Build the experiment config from command-line arguments (the module
    # docstring doubles as the CLI description) and start training.
    main(
        tyro.cli(
            kitti.experiment_config.FactorGraphExperimentConfig,
            description=__doc__,
        )
    )