-
Notifications
You must be signed in to change notification settings - Fork 10
/
kitti_compute_dynamics_rmses.py
75 lines (60 loc) · 2.37 KB
/
kitti_compute_dynamics_rmses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""Script for printing the expected RMSE of the analytical dynamics model.
Uses the entire dataset, so this is a little bit wrong in that it leaks test set
information into training. But this doesn't impact results in practice: the learned
values (5 scalars) should only be used for initialization, and end up being tuned
end-to-end (on only the train set)."""
import jax
import jaxlie
import numpy as onp
from jax import numpy as jnp
from tqdm.auto import tqdm
from lib import kitti
# Alias kept for backward compatibility with code that may import it. Note that the
# function below returns a (5,)-shaped array, not a scalar.
Scalar = jnp.ndarray


@jax.jit
def compute_subsequence_sum_squared_error(
    subsequence: kitti.data.KittiStructRaw,
) -> jnp.ndarray:
    """Sum the squared one-step prediction errors of the analytical dynamics model.

    Every state in ``subsequence`` is pushed through
    ``kitti.fg_system.State.predict_next``. The prediction made from step ``t`` is
    compared, via ``kitti.fg_system.State.manifold_minus``, against the ground-truth
    state at step ``t + 1``.

    Args:
        subsequence: One subsequence whose ``x``, ``y`` and ``theta`` fields are 1-D
            arrays of identical length ``timesteps``.

    Returns:
        Array of shape ``(5,)``: squared errors summed over the ``timesteps - 1``
        transitions, one entry per component of the ``manifold_minus`` output.
    """
    # Quick shape checks. Shapes are static under jit, so these run at trace time.
    (timesteps,) = subsequence.x.shape
    assert subsequence.y.shape == subsequence.theta.shape == (timesteps,)

    # Build a stacked (batched) state object, then predict one step ahead for each.
    states: kitti.fg_system.State = jax.vmap(kitti.fg_system.State.make)(
        pose=jax.vmap(jaxlie.SE2.from_xy_theta)(
            subsequence.x, subsequence.y, subsequence.theta
        ),
        linear_vel=subsequence.linear_vel,
        angular_vel=subsequence.angular_vel,
    )
    predicted_states: kitti.fg_system.State = jax.vmap(
        kitti.fg_system.State.predict_next
    )(states)

    # Align: the prediction made from state t is compared against true state t + 1.
    # `jax.tree_util.tree_map` replaces the deprecated (now removed) `jax.tree_map`.
    states = jax.tree_util.tree_map(lambda x: x[1:], states)
    predicted_states = jax.tree_util.tree_map(lambda x: x[:-1], predicted_states)

    # Per-transition, per-component squared errors.
    squared_errors = (
        jax.vmap(kitti.fg_system.State.manifold_minus)(states, predicted_states) ** 2
    )
    assert squared_errors.shape == (timesteps - 1, 5)
    return jnp.sum(squared_errors, axis=0)
def main() -> None:
    """Load KITTI trajectories and print the per-dimension RMSE of one-step predictions.

    Squared errors are accumulated over every transition of every disjoint
    subsequence and normalized by the total number of transitions, so the printed
    values are per-timestep RMSEs (one per component of the state error).

    Raises:
        ValueError: If the loaded data yields no transitions to evaluate.
    """
    print("Loading data...")
    trajectories = kitti.data_loading.load_trajectories_from_ids(
        # Note: for the kitti-10 dataset, we should really be skipping trajectory #1.
        # But this is not really important, because the RMSEs we compute are only being
        # used for initialization.
        range(11),
        verbose=False,
    )
    subsequences = kitti.data_loading.make_disjoint_subsequences(
        trajectories, subsequence_length=100
    )

    print("Computing statistics...")
    sse = onp.zeros(5)
    num_transitions = 0
    for subsequence in tqdm(subsequences):
        # Pull the result back to host NumPy so accumulation happens in float64.
        sse += onp.asarray(compute_subsequence_sum_squared_error(subsequence))
        # A subsequence of length T contributes T - 1 one-step transitions; the
        # SSE above is summed over exactly those transitions.
        num_transitions += subsequence.x.shape[0] - 1

    if num_transitions == 0:
        raise ValueError("No transitions available to compute RMSE over.")

    # Normalize by transitions, not subsequences: dividing by len(subsequences)
    # would inflate the MSE by a factor of (subsequence_length - 1).
    mse = sse / num_transitions
    rmse = onp.sqrt(mse)
    print("RMSE:", rmse)


if __name__ == "__main__":
    main()