-
Notifications
You must be signed in to change notification settings - Fork 10
/
kitti_transfer_ablation.py
84 lines (70 loc) · 3.06 KB
/
kitti_transfer_ablation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import dataclasses
import pathlib
import fifteen
import jax_dataclasses
import tyro
from tqdm.auto import tqdm
from lib import kitti
@dataclasses.dataclass
class Args:
    """Arguments for the KITTI noise-model transfer ablation.

    For every dataset fold, the learnable parameters saved by the factor-graph
    (FG) experiment are loaded into an EKF train state and vice versa; both
    swapped models are then evaluated and their metrics written out.

    Every identifier is a `str.format` template resolved relative to
    `experiments_dir`; `{dataset_fold}` is replaced by the fold index.
    """

    # FG experiment whose saved parameters are loaded into the EKF.
    fg_experiment_identifier: str = (
        "kitti/fg/hetero/surrogate_pos/ground_truth-5/fold_{dataset_fold}"
    )
    # EKF experiment whose saved parameters are loaded into the FG.
    ekf_experiment_identifier: str = "kitti/ekf/hetero/fold_{dataset_fold}"
    # Folds 0 .. num_folds - 1 are processed.
    num_folds: int = 10
    # Root directory that every identifier above and below is resolved against.
    experiments_dir: pathlib.Path = pathlib.Path("./experiments/")
    # Output for metrics of the EKF running FG-trained parameters.
    # NOTE: `.clear()` is called on this experiment before metrics are written.
    ekf_output_identifier: str = "kitti/ekf/hetero/trained_on_fg/fold_{dataset_fold}"
    # Output for metrics of the FG running EKF-trained parameters.
    # NOTE: `.clear()` is called on this experiment before metrics are written.
    fg_output_identifier: str = "kitti/fg/hetero/trained_on_ekf/fold_{dataset_fold}"


def _experiment_dir(args: Args, identifier: str, dataset_fold: int) -> pathlib.Path:
    """Resolve an identifier template to an experiment directory for one fold."""
    return args.experiments_dir / identifier.format(dataset_fold=dataset_fold)


def _with_params_from(train_state, source_experiment):
    """Return a copy of `train_state` whose `learnable_params` are restored from
    `source_experiment`'s checkpoint with prefix `best_val_params_`.

    The current `learnable_params` are passed as the restore target.
    NOTE(review): presumably the source checkpoint must therefore match this
    train state's parameter structure — confirm for each model pairing.
    """
    # `validate=False` is kept from the original script.
    # NOTE(review): presumably this disables jax_dataclasses' validation of
    # assigned fields, since the params come from the other model — confirm.
    with jax_dataclasses.copy_and_mutate(train_state, validate=False) as mutable:
        mutable.learnable_params = source_experiment.restore_checkpoint(
            target=mutable.learnable_params,
            prefix="best_val_params_",
        )
    return mutable


def main(args: Args) -> None:
    """Evaluate EKF and FG models with their learned parameters swapped.

    For each fold in ``range(args.num_folds)``:

    1. Open the existing EKF and FG experiments and read their configs.
    2. Initialize train states (``train=False``) from those configs.
    3. Load the FG experiment's ``best_val_params_`` checkpoint into the EKF
       state, and the EKF experiment's checkpoint into the FG state.
    4. Evaluate each state on the subsequence eval dataset built from its config.
    5. Clear the output experiments and write ``best_val_metrics`` to them.

    Raises:
        ValueError: if an output directory equals a source experiment directory
            or the other output directory, since outputs are cleared.
    """
    for dataset_fold in tqdm(range(args.num_folds)):
        ekf_dir = _experiment_dir(args, args.ekf_experiment_identifier, dataset_fold)
        fg_dir = _experiment_dir(args, args.fg_experiment_identifier, dataset_fold)
        ekf_out_dir = _experiment_dir(args, args.ekf_output_identifier, dataset_fold)
        fg_out_dir = _experiment_dir(args, args.fg_output_identifier, dataset_fold)

        # Outputs are cleared before writing: refuse to run if that would wipe a
        # source experiment or the other output. (Purely lexical path comparison.)
        if ekf_out_dir == fg_out_dir or {ekf_out_dir, fg_out_dir} & {ekf_dir, fg_dir}:
            raise ValueError(
                f"Output directories ({ekf_out_dir}, {fg_out_dir}) must differ from "
                f"each other and from the source experiments ({ekf_dir}, {fg_dir})."
            )

        # Experiments to transfer noise models across; both must already exist.
        ekf_experiment = fifteen.experiments.Experiment(
            data_dir=ekf_dir
        ).assert_exists()
        fg_experiment = fifteen.experiments.Experiment(
            data_dir=fg_dir
        ).assert_exists()

        # Read experiment configurations for each experiment.
        ekf_config = ekf_experiment.read_metadata(
            "experiment_config", kitti.experiment_config.EkfExperimentConfig
        )
        fg_config = fg_experiment.read_metadata(
            "experiment_config", kitti.experiment_config.FactorGraphExperimentConfig
        )

        # Initialize training states.
        ekf_train_state = kitti.training_ekf.TrainState.initialize(
            ekf_config, train=False
        )
        fg_train_state = kitti.training_fg.TrainState.initialize(fg_config, train=False)

        # Load uncertainty models... but swapped!
        ekf_train_state = _with_params_from(ekf_train_state, fg_experiment)
        fg_train_state = _with_params_from(fg_train_state, ekf_experiment)

        # Evaluate each training state; the first returned element is unused here.
        _, ekf_metrics = kitti.validation_ekf.make_compute_metrics(
            eval_dataset=kitti.data_loading.make_subsequence_eval_dataset(ekf_config)
        )(ekf_train_state)
        _, fg_metrics = kitti.validation_fg.make_compute_metrics(
            eval_dataset=kitti.data_loading.make_subsequence_eval_dataset(fg_config)
        )(fg_train_state)

        # Write metrics to the output experiments, clearing them first.
        fifteen.experiments.Experiment(data_dir=ekf_out_dir).clear().write_metadata(
            "best_val_metrics", ekf_metrics
        )
        fifteen.experiments.Experiment(data_dir=fg_out_dir).clear().write_metadata(
            "best_val_metrics", fg_metrics
        )
if __name__ == "__main__":
    # Build `Args` from command-line flags via tyro, then run the ablation.
    main(tyro.cli(Args))