From c1905bd8a34948e2eb6553eeb80fca2ee084c7b1 Mon Sep 17 00:00:00 2001 From: pengzhenghao Date: Thu, 5 Dec 2024 19:57:42 -0800 Subject: [PATCH 1/8] Introduce ScenarioOnlineDataManager and ScenarioOnlineEnv --- metadrive/envs/scenario_env.py | 10 +- metadrive/manager/scenario_data_manager.py | 169 +++++++++++++++++++++ 2 files changed, 178 insertions(+), 1 deletion(-) diff --git a/metadrive/envs/scenario_env.py b/metadrive/envs/scenario_env.py index 1ba065862..bfe798b15 100644 --- a/metadrive/envs/scenario_env.py +++ b/metadrive/envs/scenario_env.py @@ -10,7 +10,7 @@ from metadrive.envs.base_env import BaseEnv from metadrive.manager.scenario_agent_manager import ScenarioAgentManager from metadrive.manager.scenario_curriculum_manager import ScenarioCurriculumManager -from metadrive.manager.scenario_data_manager import ScenarioDataManager +from metadrive.manager.scenario_data_manager import ScenarioDataManager, ScenarioOnlineDataManager from metadrive.manager.scenario_light_manager import ScenarioLightManager from metadrive.manager.scenario_map_manager import ScenarioMapManager from metadrive.manager.scenario_traffic_manager import ScenarioTrafficManager @@ -393,6 +393,14 @@ def _reset_global_seed(self, force_seed=None): self.seed(current_seed) +class ScenarioOnlineEnv(ScenarioEnv): + """ + This environment allow the user to pass in scenario data directly. + """ + def setup_engine(self): + self.engine.update_manager("data_manager", ScenarioOnlineDataManager()) + + if __name__ == "__main__": env = ScenarioEnv( { diff --git a/metadrive/manager/scenario_data_manager.py b/metadrive/manager/scenario_data_manager.py index 8b785efce..9812bb719 100644 --- a/metadrive/manager/scenario_data_manager.py +++ b/metadrive/manager/scenario_data_manager.py @@ -172,3 +172,172 @@ def destroy(self): self.summary_lookup.clear() self.mapping.clear() self.summary_dict, self.summary_lookup, self.mapping = None, None, None + + + +class ScenarioOnlineDataManager(BaseManager): + """ + Compared to ScenarioDataManager, this manager allow user to pass in Scenario Description online. + It will not read data from disk, but receive data from user. + """ + # DEFAULT_DATA_BUFFER_SIZE = 100 + PRIORITY = -10 + + _scenario = None + + # def __init__(self): + # super(ScenarioOnlineDataManager, self).__init__() + # from metadrive.engine.engine_utils import get_engine + # engine = get_engine() + + # self.store_data = engine.global_config["store_data"] + # self.directory = engine.global_config["data_directory"] + # self.num_scenarios = engine.global_config["num_scenarios"] + # self.start_scenario_index = engine.global_config["start_scenario_index"] + + # for multi-worker + # self.worker_index = self.engine.global_config["worker_index"] + # self.available_scenario_indices = [ + # i for i in range( + # self.start_scenario_index + self.worker_index, self.start_scenario_index + + # self.num_scenarios, self.engine.global_config["num_workers"] + # ) + # ] + # self._scenarios = {} + + # Read summary file first: + # self.summary_dict, self.summary_lookup, self.mapping = read_dataset_summary(self.directory) + # self.summary_lookup[:self.start_scenario_index] = [None] * self.start_scenario_index + # end_idx = self.start_scenario_index + self.num_scenarios + # self.summary_lookup[end_idx:] = [None] * (len(self.summary_lookup) - end_idx) + + # sort scenario for curriculum training + # self.scenario_difficulty = None + # self.sort_scenarios() + + # existence check + # assert self.start_scenario_index < len(self.summary_lookup), "Insufficient scenarios!" + # assert self.start_scenario_index + self.num_scenarios <= len(self.summary_lookup), \ + # "Insufficient scenarios! Need: {} Has: {}".format(self.num_scenarios, + # len(self.summary_lookup) - self.start_scenario_index) + # + # for p in self.summary_lookup[self.start_scenario_index:end_idx]: + # p = os.path.join(self.directory, self.mapping[p], p) + # assert os.path.exists(p), "No Data at path: {}".format(p) + # + # # stat + # self.coverage = [0 for _ in range(self.num_scenarios)] + + @property + def current_scenario_summary(self): + return self.current_scenario[SD.METADATA] + + # def _get_scenario(self, i): + # assert i in self.available_scenario_indices, \ + # "scenario index exceeds range, scenario index: {}, worker_index: {}".format(i, self.worker_index) + # assert i < len(self.summary_lookup) + # scenario_id = self.summary_lookup[i] + # file_path = os.path.join(self.directory, self.mapping[scenario_id], scenario_id) + # ret = read_scenario_data(file_path, centralize=True) + # assert isinstance(ret, SD) + # return ret + + # def before_reset(self): + # if not self.store_data: + # assert len(self._scenarios) <= 1, "It seems you access multiple scenarios in one episode" + # self._scenarios = {} + + def set_scenario(self, scenario_description): + SD.sanity_check(scenario_description) + self._scenario = scenario_description + + def get_scenario(self, seed=None): + return self._scenario + + def get_metadata(self): + raise ValueError() + state = super(ScenarioDataManager, self).get_metadata() + raw_data = self.current_scenario + state["raw_data"] = raw_data + return state + + @property + def current_scenario_length(self): + return self.current_scenario[SD.LENGTH] + + @property + def current_scenario(self): + return self._scenario + + # def sort_scenarios(self): + # + # """ + # TODO(LQY): consider exposing this API to config + # Sort scenarios to support curriculum training. You are encouraged to customize your own sort method + # :return: sorted scenario list + # """ + # if self.engine.max_level == 0: + # raise ValueError("Curriculum Level should be greater than 1") + # elif self.engine.max_level == 1: + # return + # + # def _score(scenario_id): + # file_path = os.path.join(self.directory, self.mapping[scenario_id], scenario_id) + # scenario = read_scenario_data(file_path, centralize=True) + # obj_weight = 0 + # + # # calculate curvature + # ego_car_id = scenario[SD.METADATA][SD.SDC_ID] + # state_dict = scenario["tracks"][ego_car_id]["state"] + # valid_track = state_dict["position"][np.where(state_dict["valid"].astype(int))][..., :2] + # + # dir = valid_track[1:] - valid_track[:-1] + # dir = np.arctan2(dir[..., 1], dir[..., 0]) + # curvature = sum(abs(dir[1:] - dir[:-1]) / np.pi) + 1 + # + # sdc_moving_dist = SD.sdc_moving_dist(scenario) + # num_moving_objs = SD.num_moving_object(scenario, object_type=MetaDriveType.VEHICLE) + # return sdc_moving_dist * curvature + num_moving_objs * obj_weight, scenario + # + # start = self.start_scenario_index + # end = self.start_scenario_index + self.num_scenarios + # id_score_scenarios = [(s_id, *_score(s_id)) for s_id in self.summary_lookup[start:end]] + # id_score_scenarios = sorted(id_score_scenarios, key=lambda scenario: scenario[-2]) + # self.summary_lookup[start:end] = [id_score_scenario[0] for id_score_scenario in id_score_scenarios] + # self.scenario_difficulty = { + # id_score_scenario[0]: id_score_scenario[1] + # for id_score_scenario in id_score_scenarios + # } + # self._scenarios = {i + start: id_score_scenario[-1] for i, id_score_scenario in enumerate(id_score_scenarios)} + # + # def clear_stored_scenarios(self): + # self._scenarios = {} + # + # @property + # def current_scenario_difficulty(self): + # return self.scenario_difficulty[self.summary_lookup[self.engine.global_random_seed] + # ] if self.scenario_difficulty is not None else 0 + + @property + def current_scenario_id(self): + return self.current_scenario_summary["scenario_id"] + + # @property + # def current_scenario_file_name(self): + # return self.summary_lookup[self.engine.global_random_seed] + # + # @property + # def data_coverage(self): + # return sum(self.coverage) / len(self.coverage) * self.engine.global_config["num_workers"] + + def destroy(self): + """ + Clear memory + """ + super(ScenarioOnlineDataManager, self).destroy() + self._scenario = None + # self._scenarios = {} + # Config.clear_nested_dict(self.summary_dict) + # self.summary_lookup.clear() + # self.mapping.clear() + # self.summary_dict, self.summary_lookup, self.mapping = None, None, None From 214b21fba495080e78d4fa52a99acd898860811c Mon Sep 17 00:00:00 2001 From: pengzhenghao Date: Thu, 5 Dec 2024 20:16:06 -0800 Subject: [PATCH 2/8] Prepare the ScenarioOnlineEnv --- metadrive/envs/scenario_env.py | 10 ++ metadrive/manager/scenario_data_manager.py | 131 ++---------------- .../test_env/test_scenario_online_env.py | 49 +++++++ 3 files changed, 68 insertions(+), 122 deletions(-) create mode 100644 metadrive/tests/test_env/test_scenario_online_env.py diff --git a/metadrive/envs/scenario_env.py b/metadrive/envs/scenario_env.py index bfe798b15..094973425 100644 --- a/metadrive/envs/scenario_env.py +++ b/metadrive/envs/scenario_env.py @@ -397,9 +397,19 @@ class ScenarioOnlineEnv(ScenarioEnv): """ This environment allow the user to pass in scenario data directly. """ + def __init__(self, config=None): + super(ScenarioOnlineEnv, self).__init__(config) + self.lazy_init() + def setup_engine(self): + """Overwrite the data_manager by ScenarioOnlineDataManager""" + super().setup_engine() self.engine.update_manager("data_manager", ScenarioOnlineDataManager()) + def set_scenario(self, scenario_data): + """Please call this function before env.reset()""" + self.engine.data_manager.set_scenario(scenario_data) + if __name__ == "__main__": env = ScenarioEnv( diff --git a/metadrive/manager/scenario_data_manager.py b/metadrive/manager/scenario_data_manager.py index 9812bb719..779acc2bd 100644 --- a/metadrive/manager/scenario_data_manager.py +++ b/metadrive/manager/scenario_data_manager.py @@ -174,84 +174,25 @@ def destroy(self): self.summary_dict, self.summary_lookup, self.mapping = None, None, None - class ScenarioOnlineDataManager(BaseManager): """ Compared to ScenarioDataManager, this manager allow user to pass in Scenario Description online. It will not read data from disk, but receive data from user. """ - # DEFAULT_DATA_BUFFER_SIZE = 100 PRIORITY = -10 - _scenario = None - # def __init__(self): - # super(ScenarioOnlineDataManager, self).__init__() - # from metadrive.engine.engine_utils import get_engine - # engine = get_engine() - - # self.store_data = engine.global_config["store_data"] - # self.directory = engine.global_config["data_directory"] - # self.num_scenarios = engine.global_config["num_scenarios"] - # self.start_scenario_index = engine.global_config["start_scenario_index"] - - # for multi-worker - # self.worker_index = self.engine.global_config["worker_index"] - # self.available_scenario_indices = [ - # i for i in range( - # self.start_scenario_index + self.worker_index, self.start_scenario_index + - # self.num_scenarios, self.engine.global_config["num_workers"] - # ) - # ] - # self._scenarios = {} - - # Read summary file first: - # self.summary_dict, self.summary_lookup, self.mapping = read_dataset_summary(self.directory) - # self.summary_lookup[:self.start_scenario_index] = [None] * self.start_scenario_index - # end_idx = self.start_scenario_index + self.num_scenarios - # self.summary_lookup[end_idx:] = [None] * (len(self.summary_lookup) - end_idx) - - # sort scenario for curriculum training - # self.scenario_difficulty = None - # self.sort_scenarios() - - # existence check - # assert self.start_scenario_index < len(self.summary_lookup), "Insufficient scenarios!" - # assert self.start_scenario_index + self.num_scenarios <= len(self.summary_lookup), \ - # "Insufficient scenarios! Need: {} Has: {}".format(self.num_scenarios, - # len(self.summary_lookup) - self.start_scenario_index) - # - # for p in self.summary_lookup[self.start_scenario_index:end_idx]: - # p = os.path.join(self.directory, self.mapping[p], p) - # assert os.path.exists(p), "No Data at path: {}".format(p) - # - # # stat - # self.coverage = [0 for _ in range(self.num_scenarios)] - @property def current_scenario_summary(self): return self.current_scenario[SD.METADATA] - # def _get_scenario(self, i): - # assert i in self.available_scenario_indices, \ - # "scenario index exceeds range, scenario index: {}, worker_index: {}".format(i, self.worker_index) - # assert i < len(self.summary_lookup) - # scenario_id = self.summary_lookup[i] - # file_path = os.path.join(self.directory, self.mapping[scenario_id], scenario_id) - # ret = read_scenario_data(file_path, centralize=True) - # assert isinstance(ret, SD) - # return ret - - # def before_reset(self): - # if not self.store_data: - # assert len(self._scenarios) <= 1, "It seems you access multiple scenarios in one episode" - # self._scenarios = {} - def set_scenario(self, scenario_description): SD.sanity_check(scenario_description) self._scenario = scenario_description - def get_scenario(self, seed=None): + def get_scenario(self, seed=None, should_copy=False): + if should_copy: + return copy.deepcopy(self._scenario) return self._scenario def get_metadata(self): @@ -269,66 +210,17 @@ def current_scenario_length(self): def current_scenario(self): return self._scenario - # def sort_scenarios(self): - # - # """ - # TODO(LQY): consider exposing this API to config - # Sort scenarios to support curriculum training. You are encouraged to customize your own sort method - # :return: sorted scenario list - # """ - # if self.engine.max_level == 0: - # raise ValueError("Curriculum Level should be greater than 1") - # elif self.engine.max_level == 1: - # return - # - # def _score(scenario_id): - # file_path = os.path.join(self.directory, self.mapping[scenario_id], scenario_id) - # scenario = read_scenario_data(file_path, centralize=True) - # obj_weight = 0 - # - # # calculate curvature - # ego_car_id = scenario[SD.METADATA][SD.SDC_ID] - # state_dict = scenario["tracks"][ego_car_id]["state"] - # valid_track = state_dict["position"][np.where(state_dict["valid"].astype(int))][..., :2] - # - # dir = valid_track[1:] - valid_track[:-1] - # dir = np.arctan2(dir[..., 1], dir[..., 0]) - # curvature = sum(abs(dir[1:] - dir[:-1]) / np.pi) + 1 - # - # sdc_moving_dist = SD.sdc_moving_dist(scenario) - # num_moving_objs = SD.num_moving_object(scenario, object_type=MetaDriveType.VEHICLE) - # return sdc_moving_dist * curvature + num_moving_objs * obj_weight, scenario - # - # start = self.start_scenario_index - # end = self.start_scenario_index + self.num_scenarios - # id_score_scenarios = [(s_id, *_score(s_id)) for s_id in self.summary_lookup[start:end]] - # id_score_scenarios = sorted(id_score_scenarios, key=lambda scenario: scenario[-2]) - # self.summary_lookup[start:end] = [id_score_scenario[0] for id_score_scenario in id_score_scenarios] - # self.scenario_difficulty = { - # id_score_scenario[0]: id_score_scenario[1] - # for id_score_scenario in id_score_scenarios - # } - # self._scenarios = {i + start: id_score_scenario[-1] for i, id_score_scenario in enumerate(id_score_scenarios)} - # - # def clear_stored_scenarios(self): - # self._scenarios = {} - # - # @property - # def current_scenario_difficulty(self): - # return self.scenario_difficulty[self.summary_lookup[self.engine.global_random_seed] - # ] if self.scenario_difficulty is not None else 0 + @property + def current_scenario_difficulty(self): + return 0 @property def current_scenario_id(self): return self.current_scenario_summary["scenario_id"] - # @property - # def current_scenario_file_name(self): - # return self.summary_lookup[self.engine.global_random_seed] - # - # @property - # def data_coverage(self): - # return sum(self.coverage) / len(self.coverage) * self.engine.global_config["num_workers"] + @property + def data_coverage(self): + return None def destroy(self): """ @@ -336,8 +228,3 @@ def destroy(self): """ super(ScenarioOnlineDataManager, self).destroy() self._scenario = None - # self._scenarios = {} - # Config.clear_nested_dict(self.summary_dict) - # self.summary_lookup.clear() - # self.mapping.clear() - # self.summary_dict, self.summary_lookup, self.mapping = None, None, None diff --git a/metadrive/tests/test_env/test_scenario_online_env.py b/metadrive/tests/test_env/test_scenario_online_env.py new file mode 100644 index 000000000..ea2914946 --- /dev/null +++ b/metadrive/tests/test_env/test_scenario_online_env.py @@ -0,0 +1,49 @@ +import pytest +import seaborn as sns +import numpy as np + +from metadrive.engine.asset_loader import AssetLoader +from metadrive.envs.scenario_env import ScenarioOnlineEnv +from metadrive.policy.idm_policy import TrajectoryIDMPolicy +from metadrive.policy.replay_policy import ReplayEgoCarPolicy +from metadrive.scenario.utils import read_dataset_summary, read_scenario_data +import pickle +import pathlib + +from metadrive.policy.replay_policy import ReplayEgoCarPolicy + + +@pytest.mark.parametrize("data_directory", ["waymo", "nuscenes"]) +def test_scenario_online_env(data_directory, render=False): + path = pathlib.Path(AssetLoader.file_path(AssetLoader.asset_path, data_directory, unix_style=False)) + summary, scenario_ids, mapping = read_dataset_summary(path) + try: + env = ScenarioOnlineEnv(config=dict( + use_render=render, + agent_policy=ReplayEgoCarPolicy, + )) + for file_name, file_path in mapping.items(): + full_path = path / file_path / file_name + assert full_path.exists(), f"{full_path} does not exist" + scenario_description = read_scenario_data(full_path) + print("Running scenario: ", scenario_description["id"]) + env.set_scenario(scenario_description) + env.reset() + for i in range(1000): + o, r, tm, tc, info = env.step([1.0, 0.]) + assert env.observation_space.contains(o) + if tm or tc: + assert info["arrive_dest"], "Can not arrive dest" + print("{} track_length: ".format(env.engine.global_seed), info["track_length"]) + # assert info["arrive_dest"], "Can not arrive dest" + break + + if i == 999: + raise ValueError("Can not arrive dest") + assert env.agent.panda_color == sns.color_palette("colorblind")[2] + finally: + env.close() + + +if __name__ == "__main__": + test_scenario_online_env("nuscenes", render=True) From af7e31558fb74cac15f12245863ec874691f7cd6 Mon Sep 17 00:00:00 2001 From: pengzhenghao Date: Thu, 5 Dec 2024 20:23:06 -0800 Subject: [PATCH 3/8] revert the bug causing tire-scale code --- metadrive/component/vehicle/base_vehicle.py | 31 ++------------------- 1 file changed, 2 insertions(+), 29 deletions(-) diff --git a/metadrive/component/vehicle/base_vehicle.py b/metadrive/component/vehicle/base_vehicle.py index 607289d69..7111c418b 100644 --- a/metadrive/component/vehicle/base_vehicle.py +++ b/metadrive/component/vehicle/base_vehicle.py @@ -145,15 +145,6 @@ def __init__( self.add_body(vehicle_chassis.getChassis()) self.system = vehicle_chassis self.chassis = self.origin - - if self.config["scale"] is not None: - w, l, h = self.config["scale"] - self.FRONT_WHEELBASE *= l - self.REAR_WHEELBASE *= l - self.LATERAL_TIRE_TO_CENTER *= w - self.TIRE_RADIUS *= h - self.CHASSIS_TO_WHEEL_AXIS *= h - self.wheels = self._create_wheel() # light experimental! @@ -635,7 +626,7 @@ def _create_vehicle_chassis(self): def _add_visualization(self): if self.render: - path, scale, offset, HPR = self.path + [path, scale, offset, HPR] = self.path should_update = (path not in BaseVehicle.model_collection) or (self.config["scale"] is not None) if should_update: @@ -659,10 +650,8 @@ def _add_visualization(self): car_model.setHpr(*HPR) car_model.setPos(offset[0], offset[1], offset[2] + extra_offset_z) BaseVehicle.model_collection[path] = car_model - else: car_model = BaseVehicle.model_collection[path] - car_model.instanceTo(self.origin) if self.config["random_color"]: material = Material() @@ -704,23 +693,7 @@ def _add_wheel(self, pos: Vec3, radius: float, front: bool, left): wheel_model = self.loader.loadModel(model_path) wheel_model.setTwoSided(self.TIRE_TWO_SIDED) wheel_model.reparentTo(wheel_np) - tire_scale = 1 * self.TIRE_MODEL_CORRECT if left else -1 * self.TIRE_MODEL_CORRECT - - if self.config['scale'] is not None: - tire_scale = ( - self.config['scale'][0] * tire_scale, self.config['scale'][1] * tire_scale, - self.config['scale'][2] * tire_scale - ) - - # A quick workaround here. - # The model position is set to height/2 in ScenarioMapManager. - # Now we set this offset to -height/2, so that the model will be placed on the ground. - # For the wheel, the bottom of it is not z=0, so we add two more terms to correct it. - extra_offset = -self.config["height"] / 2 + self.TIRE_RADIUS / self.config['scale'][ - 2] + self.CHASSIS_TO_WHEEL_AXIS / self.config['scale'][2] - wheel_model.setPos(0, 0, extra_offset) - - wheel_model.set_scale(tire_scale) + wheel_model.set_scale(1 * self.TIRE_MODEL_CORRECT if left else -1 * self.TIRE_MODEL_CORRECT) wheel = self.system.createWheel() wheel.setNode(wheel_np.node()) wheel.setChassisConnectionPointCs(pos) From 0bdb17e693370f35ba7cd5ae111071c93c882472 Mon Sep 17 00:00:00 2001 From: pengzhenghao Date: Thu, 5 Dec 2024 20:28:41 -0800 Subject: [PATCH 4/8] Hot fix the bug --- metadrive/manager/scenario_traffic_manager.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/metadrive/manager/scenario_traffic_manager.py b/metadrive/manager/scenario_traffic_manager.py index e43c5f831..325e1b66f 100644 --- a/metadrive/manager/scenario_traffic_manager.py +++ b/metadrive/manager/scenario_traffic_manager.py @@ -211,10 +211,11 @@ def spawn_vehicle(self, v_id, track): v_cfg["width"] = state["width"] v_cfg["length"] = state["length"] v_cfg["height"] = state["height"] - v_cfg["scale"] = ( - v_cfg["width"] / vehicle_class.DEFAULT_WIDTH, v_cfg["length"] / vehicle_class.DEFAULT_LENGTH, - v_cfg["height"] / vehicle_class.DEFAULT_HEIGHT - ) + if use_bounding_box: + v_cfg["scale"] = ( + v_cfg["width"] / vehicle_class.DEFAULT_WIDTH, v_cfg["length"] / vehicle_class.DEFAULT_LENGTH, + v_cfg["height"] / vehicle_class.DEFAULT_HEIGHT + ) if self.engine.global_config["top_down_show_real_size"]: v_cfg["top_down_length"] = track["state"]["length"][self.episode_step] From 44fea918bafe6040b6c37c29e6b73befb98b0a20 Mon Sep 17 00:00:00 2001 From: pengzhenghao Date: Thu, 5 Dec 2024 20:35:59 -0800 Subject: [PATCH 5/8] Allow to config "set_static", default to False to fix the bug of flickerring visualization. --- metadrive/envs/scenario_env.py | 11 +++++++++-- metadrive/policy/replay_policy.py | 9 ++++++++- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/metadrive/envs/scenario_env.py b/metadrive/envs/scenario_env.py index 1ba065862..fa35f178b 100644 --- a/metadrive/envs/scenario_env.py +++ b/metadrive/envs/scenario_env.py @@ -59,6 +59,12 @@ lane_line_detector=dict(num_lasers=0, distance=50), side_detector=dict(num_lasers=12, distance=50), ), + # If set_static=True, then the agent will not "fall from the sky". This will be helpful if you want to + # capture per-frame data for the agent (for example for collecting static sensor data). + # However, the physics engine will not update the position of the agent. So in the visualization, the image will be + # very chunky as the agent will not suddenly move to the next position for each step. + # Set to False for better visualization. + set_static=False, # ===== Reward Scheme ===== # See: https://github.com/metadriverse/metadrive/issues/283 @@ -410,7 +416,8 @@ def _reset_global_seed(self, force_seed=None): # "no_traffic":True, # "start_scenario_index": 192, # "start_scenario_index": 1000, - "num_scenarios": 30, + "num_scenarios": 3, + "set_static": True, # "force_reuse_object_name": True, # "data_directory": "/home/shady/Downloads/test_processed", "horizon": 1000, @@ -424,7 +431,7 @@ def _reset_global_seed(self, force_seed=None): lane_line_detector=dict(num_lasers=12, distance=50), side_detector=dict(num_lasers=160, distance=50) ), - "data_directory": AssetLoader.file_path("nuplan", unix_style=False), + "data_directory": AssetLoader.file_path("nuscenes", unix_style=False), } ) success = [] diff --git a/metadrive/policy/replay_policy.py b/metadrive/policy/replay_policy.py index 28ccce036..f66b3097d 100644 --- a/metadrive/policy/replay_policy.py +++ b/metadrive/policy/replay_policy.py @@ -63,7 +63,14 @@ def act(self, *args, **kwargs): self.control_object.set_velocity(info["velocity"], in_local_frame=self._velocity_local_frame) self.control_object.set_heading_theta(info["heading"]) self.control_object.set_angular_velocity(info["angular_velocity"]) - self.control_object.set_static(True) + + # If set_static, then the agent will not "fall from the sky". + # However, the physics engine will not update the position of the agent. + # So in the visualization, the image will be very chunky as the agent will not suddenly move to the next + # position for each step. + + if self.engine.global_config.get("set_static", False): + self.control_object.set_static(True) return None # Return None action so the base vehicle will not overwrite the steering & throttle From b8b81cdcb351d7ef5dc139e573c078d95198e3a3 Mon Sep 17 00:00:00 2001 From: pengzhenghao Date: Thu, 5 Dec 2024 20:40:48 -0800 Subject: [PATCH 6/8] add a warning --- metadrive/manager/scenario_data_manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/metadrive/manager/scenario_data_manager.py b/metadrive/manager/scenario_data_manager.py index 779acc2bd..50e74514b 100644 --- a/metadrive/manager/scenario_data_manager.py +++ b/metadrive/manager/scenario_data_manager.py @@ -191,6 +191,7 @@ def set_scenario(self, scenario_description): self._scenario = scenario_description def get_scenario(self, seed=None, should_copy=False): + assert self._scenario is not None, "Please set scenario first via env.set_scenario(scenario_description)!" if should_copy: return copy.deepcopy(self._scenario) return self._scenario From efd163eee30e09114b1ca67c27412c708ba4c184 Mon Sep 17 00:00:00 2001 From: pengzhenghao Date: Thu, 5 Dec 2024 20:43:07 -0800 Subject: [PATCH 7/8] Introduce an example script --- metadrive/examples/run_scenario_online_env.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 metadrive/examples/run_scenario_online_env.py diff --git a/metadrive/examples/run_scenario_online_env.py b/metadrive/examples/run_scenario_online_env.py new file mode 100644 index 000000000..3c6eff8c8 --- /dev/null +++ b/metadrive/examples/run_scenario_online_env.py @@ -0,0 +1,46 @@ +""" +This script demonstrates how to run the scenario online env. The scenario online env is a special environment that +allows users to pass in scenario descriptions online. This is useful for using scenarios generated by external +online algorithms. + +In this script, we load the Waymo dataset and run the scenarios in the dataset. We use the +ReplayEgoCarPolicy as the agent policy. +""" +import pathlib + +import seaborn as sns + +from metadrive.engine.asset_loader import AssetLoader +from metadrive.envs.scenario_env import ScenarioOnlineEnv +from metadrive.policy.replay_policy import ReplayEgoCarPolicy +from metadrive.scenario.utils import read_dataset_summary, read_scenario_data + +if __name__ == "__main__": + data_directory = "waymo" + render = True + + path = pathlib.Path(AssetLoader.file_path(AssetLoader.asset_path, data_directory, unix_style=False)) + summary, scenario_ids, mapping = read_dataset_summary(path) + try: + env = ScenarioOnlineEnv(config=dict( + use_render=render, + agent_policy=ReplayEgoCarPolicy, + )) + for file_name, file_path in mapping.items(): + full_path = path / file_path / file_name + assert full_path.exists(), f"{full_path} does not exist" + scenario_description = read_scenario_data(full_path) + print("Running scenario: ", scenario_description["id"]) + env.set_scenario(scenario_description) + env.reset() + for i in range(1000): + o, r, tm, tc, info = env.step([1.0, 0.]) + assert env.observation_space.contains(o) + if tm or tc: + break + + if i == 999: + raise ValueError("Can not arrive dest") + assert env.agent.panda_color == sns.color_palette("colorblind")[2] + finally: + env.close() From 5526a9ec0da4f2198a50b982af495341bedf5702 Mon Sep 17 00:00:00 2001 From: pengzhenghao Date: Thu, 5 Dec 2024 20:47:49 -0800 Subject: [PATCH 8/8] Fix the map bug: in metadrive, all SD should be centralized. --- metadrive/manager/scenario_data_manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/metadrive/manager/scenario_data_manager.py b/metadrive/manager/scenario_data_manager.py index 50e74514b..e9792743b 100644 --- a/metadrive/manager/scenario_data_manager.py +++ b/metadrive/manager/scenario_data_manager.py @@ -188,6 +188,7 @@ def current_scenario_summary(self): def set_scenario(self, scenario_description): SD.sanity_check(scenario_description) + scenario_description = SD.centralize_to_ego_car_initial_position(scenario_description) self._scenario = scenario_description def get_scenario(self, seed=None, should_copy=False):