From 5bf8ea8909c4643a4099a250e6f5fb89c695d8b4 Mon Sep 17 00:00:00 2001 From: Zhenghao Peng Date: Thu, 5 Dec 2024 21:32:00 -0800 Subject: [PATCH] Introduce the ScenarioOnlineEnv (#779) * Introduce ScenarioOnlineDataManager and ScenarioOnlineEnv * Prepare the ScenarioOnlineEnv * revert the bug causing tire-scale code * Hot fix the bug * Allow to config "set_static", default to False to fix the bug of flickerring visualization. * add a warning * Introduce an example script * Fix the map bug: in metadrive, all SD should be centralized. --- metadrive/envs/scenario_env.py | 20 ++++++- metadrive/examples/run_scenario_online_env.py | 46 +++++++++++++++ metadrive/manager/scenario_data_manager.py | 58 +++++++++++++++++++ .../test_env/test_scenario_online_env.py | 49 ++++++++++++++++ 4 files changed, 172 insertions(+), 1 deletion(-) create mode 100644 metadrive/examples/run_scenario_online_env.py create mode 100644 metadrive/tests/test_env/test_scenario_online_env.py diff --git a/metadrive/envs/scenario_env.py b/metadrive/envs/scenario_env.py index fa35f178b..a8837d587 100644 --- a/metadrive/envs/scenario_env.py +++ b/metadrive/envs/scenario_env.py @@ -10,7 +10,7 @@ from metadrive.envs.base_env import BaseEnv from metadrive.manager.scenario_agent_manager import ScenarioAgentManager from metadrive.manager.scenario_curriculum_manager import ScenarioCurriculumManager -from metadrive.manager.scenario_data_manager import ScenarioDataManager +from metadrive.manager.scenario_data_manager import ScenarioDataManager, ScenarioOnlineDataManager from metadrive.manager.scenario_light_manager import ScenarioLightManager from metadrive.manager.scenario_map_manager import ScenarioMapManager from metadrive.manager.scenario_traffic_manager import ScenarioTrafficManager @@ -399,6 +399,24 @@ def _reset_global_seed(self, force_seed=None): self.seed(current_seed) +class ScenarioOnlineEnv(ScenarioEnv): + """ + This environment allow the user to pass in scenario data directly. + """ + def __init__(self, config=None): + super(ScenarioOnlineEnv, self).__init__(config) + self.lazy_init() + + def setup_engine(self): + """Overwrite the data_manager by ScenarioOnlineDataManager""" + super().setup_engine() + self.engine.update_manager("data_manager", ScenarioOnlineDataManager()) + + def set_scenario(self, scenario_data): + """Please call this function before env.reset()""" + self.engine.data_manager.set_scenario(scenario_data) + + if __name__ == "__main__": env = ScenarioEnv( { diff --git a/metadrive/examples/run_scenario_online_env.py b/metadrive/examples/run_scenario_online_env.py new file mode 100644 index 000000000..3c6eff8c8 --- /dev/null +++ b/metadrive/examples/run_scenario_online_env.py @@ -0,0 +1,46 @@ +""" +This script demonstrates how to run the scenario online env. The scenario online env is a special environment that +allows users to pass in scenario descriptions online. This is useful for using scenarios generated by external +online algorithms. + +In this script, we load the Waymo dataset and run the scenarios in the dataset. We use the +ReplayEgoCarPolicy as the agent policy. +""" +import pathlib + +import seaborn as sns + +from metadrive.engine.asset_loader import AssetLoader +from metadrive.envs.scenario_env import ScenarioOnlineEnv +from metadrive.policy.replay_policy import ReplayEgoCarPolicy +from metadrive.scenario.utils import read_dataset_summary, read_scenario_data + +if __name__ == "__main__": + data_directory = "waymo" + render = True + + path = pathlib.Path(AssetLoader.file_path(AssetLoader.asset_path, data_directory, unix_style=False)) + summary, scenario_ids, mapping = read_dataset_summary(path) + try: + env = ScenarioOnlineEnv(config=dict( + use_render=render, + agent_policy=ReplayEgoCarPolicy, + )) + for file_name, file_path in mapping.items(): + full_path = path / file_path / file_name + assert full_path.exists(), f"{full_path} does not exist" + scenario_description = read_scenario_data(full_path) + print("Running scenario: ", scenario_description["id"]) + env.set_scenario(scenario_description) + env.reset() + for i in range(1000): + o, r, tm, tc, info = env.step([1.0, 0.]) + assert env.observation_space.contains(o) + if tm or tc: + break + + if i == 999: + raise ValueError("Can not arrive dest") + assert env.agent.panda_color == sns.color_palette("colorblind")[2] + finally: + env.close() diff --git a/metadrive/manager/scenario_data_manager.py b/metadrive/manager/scenario_data_manager.py index 8b785efce..e9792743b 100644 --- a/metadrive/manager/scenario_data_manager.py +++ b/metadrive/manager/scenario_data_manager.py @@ -172,3 +172,61 @@ def destroy(self): self.summary_lookup.clear() self.mapping.clear() self.summary_dict, self.summary_lookup, self.mapping = None, None, None + + +class ScenarioOnlineDataManager(BaseManager): + """ + Compared to ScenarioDataManager, this manager allow user to pass in Scenario Description online. + It will not read data from disk, but receive data from user. + """ + PRIORITY = -10 + _scenario = None + + @property + def current_scenario_summary(self): + return self.current_scenario[SD.METADATA] + + def set_scenario(self, scenario_description): + SD.sanity_check(scenario_description) + scenario_description = SD.centralize_to_ego_car_initial_position(scenario_description) + self._scenario = scenario_description + + def get_scenario(self, seed=None, should_copy=False): + assert self._scenario is not None, "Please set scenario first via env.set_scenario(scenario_description)!" + if should_copy: + return copy.deepcopy(self._scenario) + return self._scenario + + def get_metadata(self): + raise ValueError() + state = super(ScenarioDataManager, self).get_metadata() + raw_data = self.current_scenario + state["raw_data"] = raw_data + return state + + @property + def current_scenario_length(self): + return self.current_scenario[SD.LENGTH] + + @property + def current_scenario(self): + return self._scenario + + @property + def current_scenario_difficulty(self): + return 0 + + @property + def current_scenario_id(self): + return self.current_scenario_summary["scenario_id"] + + @property + def data_coverage(self): + return None + + def destroy(self): + """ + Clear memory + """ + super(ScenarioOnlineDataManager, self).destroy() + self._scenario = None diff --git a/metadrive/tests/test_env/test_scenario_online_env.py b/metadrive/tests/test_env/test_scenario_online_env.py new file mode 100644 index 000000000..ea2914946 --- /dev/null +++ b/metadrive/tests/test_env/test_scenario_online_env.py @@ -0,0 +1,49 @@ +import pytest +import seaborn as sns +import numpy as np + +from metadrive.engine.asset_loader import AssetLoader +from metadrive.envs.scenario_env import ScenarioOnlineEnv +from metadrive.policy.idm_policy import TrajectoryIDMPolicy +from metadrive.policy.replay_policy import ReplayEgoCarPolicy +from metadrive.scenario.utils import read_dataset_summary, read_scenario_data +import pickle +import pathlib + +from metadrive.policy.replay_policy import ReplayEgoCarPolicy + + +@pytest.mark.parametrize("data_directory", ["waymo", "nuscenes"]) +def test_scenario_online_env(data_directory, render=False): + path = pathlib.Path(AssetLoader.file_path(AssetLoader.asset_path, data_directory, unix_style=False)) + summary, scenario_ids, mapping = read_dataset_summary(path) + try: + env = ScenarioOnlineEnv(config=dict( + use_render=render, + agent_policy=ReplayEgoCarPolicy, + )) + for file_name, file_path in mapping.items(): + full_path = path / file_path / file_name + assert full_path.exists(), f"{full_path} does not exist" + scenario_description = read_scenario_data(full_path) + print("Running scenario: ", scenario_description["id"]) + env.set_scenario(scenario_description) + env.reset() + for i in range(1000): + o, r, tm, tc, info = env.step([1.0, 0.]) + assert env.observation_space.contains(o) + if tm or tc: + assert info["arrive_dest"], "Can not arrive dest" + print("{} track_length: ".format(env.engine.global_seed), info["track_length"]) + # assert info["arrive_dest"], "Can not arrive dest" + break + + if i == 999: + raise ValueError("Can not arrive dest") + assert env.agent.panda_color == sns.color_palette("colorblind")[2] + finally: + env.close() + + +if __name__ == "__main__": + test_scenario_online_env("nuscenes", render=True)