From 182cc4356096d49bcea2ecf7c629dc4bb91967b9 Mon Sep 17 00:00:00 2001 From: Eric Brown Date: Mon, 17 Jun 2024 20:04:56 -0600 Subject: [PATCH] Fix type hints --- tinyphysics.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tinyphysics.py b/tinyphysics.py index f73df3f..b68d61b 100644 --- a/tinyphysics.py +++ b/tinyphysics.py @@ -11,7 +11,7 @@ from functools import partial from hashlib import md5 from pathlib import Path -from typing import List, Union, Tuple +from typing import List, Union, Tuple, Dict from tqdm.contrib.concurrent import process_map from controllers import BaseController @@ -42,14 +42,14 @@ def __init__(self): self.vocab_size = VOCAB_SIZE self.bins = np.linspace(LATACCEL_RANGE[0], LATACCEL_RANGE[1], self.vocab_size) - def encode(self, value: Union[float, np.ndarray]) -> Union[int, np.ndarray]: + def encode(self, value: Union[float, np.ndarray, List[float]]) -> Union[int, np.ndarray]: value = self.clip(value) return np.digitize(value, self.bins, right=True) def decode(self, token: Union[int, np.ndarray]) -> Union[float, np.ndarray]: return self.bins[token] - def clip(self, value: Union[float, np.ndarray]) -> Union[float, np.ndarray]: + def clip(self, value: Union[float, np.ndarray, List[float]]) -> Union[float, np.ndarray]: return np.clip(value, LATACCEL_RANGE[0], LATACCEL_RANGE[1]) @@ -142,7 +142,7 @@ def control_step(self, step_idx: int) -> None: action = np.clip(action, STEER_RANGE[0], STEER_RANGE[1]) self.action_history.append(action) - def get_state_target_futureplan(self, step_idx: int) -> Tuple[State, float]: + def get_state_target_futureplan(self, step_idx: int) -> Tuple[State, float, FuturePlan]: state = self.data.iloc[step_idx] return ( State(roll_lataccel=state['roll_lataccel'], v_ego=state['v_ego'], a_ego=state['a_ego']), @@ -174,7 +174,7 @@ def plot_data(self, ax, lines, axis_labels, title) -> None: ax.set_xlabel(axis_labels[0]) ax.set_ylabel(axis_labels[1]) - def compute_cost(self) -> dict: + def compute_cost(self) -> Dict[str, float]: target = np.array(self.target_lataccel_history)[CONTROL_START_IDX:COST_END_IDX] pred = np.array(self.current_lataccel_history)[CONTROL_START_IDX:COST_END_IDX] @@ -183,7 +183,7 @@ def compute_cost(self) -> dict: total_cost = (lat_accel_cost * LAT_ACCEL_COST_MULTIPLIER) + jerk_cost return {'lataccel_cost': lat_accel_cost, 'jerk_cost': jerk_cost, 'total_cost': total_cost} - def rollout(self) -> float: + def rollout(self) -> Dict[str, float]: if self.debug: plt.ion() fig, ax = plt.subplots(4, figsize=(12, 14), constrained_layout=True)