Skip to content

Commit

Permalink
Introduce the ScenarioOnlineEnv (#779)
Browse files Browse the repository at this point in the history
* Introduce ScenarioOnlineDataManager and ScenarioOnlineEnv

* Prepare the ScenarioOnlineEnv

* revert the bug causing tire-scale code

* Hot fix the bug

* Allow to config "set_static", default to False to fix the bug of flickerring visualization.

* add a warning

* Introduce an example script

* Fix the map bug: in metadrive, all SD should be centralized.
  • Loading branch information
pengzhenghao authored Dec 6, 2024
1 parent 6b19a02 commit 5bf8ea8
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 1 deletion.
20 changes: 19 additions & 1 deletion metadrive/envs/scenario_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from metadrive.envs.base_env import BaseEnv
from metadrive.manager.scenario_agent_manager import ScenarioAgentManager
from metadrive.manager.scenario_curriculum_manager import ScenarioCurriculumManager
from metadrive.manager.scenario_data_manager import ScenarioDataManager
from metadrive.manager.scenario_data_manager import ScenarioDataManager, ScenarioOnlineDataManager
from metadrive.manager.scenario_light_manager import ScenarioLightManager
from metadrive.manager.scenario_map_manager import ScenarioMapManager
from metadrive.manager.scenario_traffic_manager import ScenarioTrafficManager
Expand Down Expand Up @@ -399,6 +399,24 @@ def _reset_global_seed(self, force_seed=None):
self.seed(current_seed)


class ScenarioOnlineEnv(ScenarioEnv):
"""
This environment allow the user to pass in scenario data directly.
"""
def __init__(self, config=None):
super(ScenarioOnlineEnv, self).__init__(config)
self.lazy_init()

def setup_engine(self):
"""Overwrite the data_manager by ScenarioOnlineDataManager"""
super().setup_engine()
self.engine.update_manager("data_manager", ScenarioOnlineDataManager())

def set_scenario(self, scenario_data):
"""Please call this function before env.reset()"""
self.engine.data_manager.set_scenario(scenario_data)


if __name__ == "__main__":
env = ScenarioEnv(
{
Expand Down
46 changes: 46 additions & 0 deletions metadrive/examples/run_scenario_online_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""
This script demonstrates how to run the scenario online env. The scenario online env is a special environment that
allows users to pass in scenario descriptions online. This is useful for using scenarios generated by external
online algorithms.
In this script, we load the Waymo dataset and run the scenarios in the dataset. We use the
ReplayEgoCarPolicy as the agent policy.
"""
import pathlib

import seaborn as sns

from metadrive.engine.asset_loader import AssetLoader
from metadrive.envs.scenario_env import ScenarioOnlineEnv
from metadrive.policy.replay_policy import ReplayEgoCarPolicy
from metadrive.scenario.utils import read_dataset_summary, read_scenario_data

if __name__ == "__main__":
data_directory = "waymo"
render = True

path = pathlib.Path(AssetLoader.file_path(AssetLoader.asset_path, data_directory, unix_style=False))
summary, scenario_ids, mapping = read_dataset_summary(path)
try:
env = ScenarioOnlineEnv(config=dict(
use_render=render,
agent_policy=ReplayEgoCarPolicy,
))
for file_name, file_path in mapping.items():
full_path = path / file_path / file_name
assert full_path.exists(), f"{full_path} does not exist"
scenario_description = read_scenario_data(full_path)
print("Running scenario: ", scenario_description["id"])
env.set_scenario(scenario_description)
env.reset()
for i in range(1000):
o, r, tm, tc, info = env.step([1.0, 0.])
assert env.observation_space.contains(o)
if tm or tc:
break

if i == 999:
raise ValueError("Can not arrive dest")
assert env.agent.panda_color == sns.color_palette("colorblind")[2]
finally:
env.close()
58 changes: 58 additions & 0 deletions metadrive/manager/scenario_data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,61 @@ def destroy(self):
self.summary_lookup.clear()
self.mapping.clear()
self.summary_dict, self.summary_lookup, self.mapping = None, None, None


class ScenarioOnlineDataManager(BaseManager):
"""
Compared to ScenarioDataManager, this manager allow user to pass in Scenario Description online.
It will not read data from disk, but receive data from user.
"""
PRIORITY = -10
_scenario = None

@property
def current_scenario_summary(self):
return self.current_scenario[SD.METADATA]

def set_scenario(self, scenario_description):
SD.sanity_check(scenario_description)
scenario_description = SD.centralize_to_ego_car_initial_position(scenario_description)
self._scenario = scenario_description

def get_scenario(self, seed=None, should_copy=False):
assert self._scenario is not None, "Please set scenario first via env.set_scenario(scenario_description)!"
if should_copy:
return copy.deepcopy(self._scenario)
return self._scenario

def get_metadata(self):
raise ValueError()
state = super(ScenarioDataManager, self).get_metadata()
raw_data = self.current_scenario
state["raw_data"] = raw_data
return state

@property
def current_scenario_length(self):
return self.current_scenario[SD.LENGTH]

@property
def current_scenario(self):
return self._scenario

@property
def current_scenario_difficulty(self):
return 0

@property
def current_scenario_id(self):
return self.current_scenario_summary["scenario_id"]

@property
def data_coverage(self):
return None

def destroy(self):
"""
Clear memory
"""
super(ScenarioOnlineDataManager, self).destroy()
self._scenario = None
49 changes: 49 additions & 0 deletions metadrive/tests/test_env/test_scenario_online_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pytest
import seaborn as sns
import numpy as np

from metadrive.engine.asset_loader import AssetLoader
from metadrive.envs.scenario_env import ScenarioOnlineEnv
from metadrive.policy.idm_policy import TrajectoryIDMPolicy
from metadrive.policy.replay_policy import ReplayEgoCarPolicy
from metadrive.scenario.utils import read_dataset_summary, read_scenario_data
import pickle
import pathlib

from metadrive.policy.replay_policy import ReplayEgoCarPolicy


@pytest.mark.parametrize("data_directory", ["waymo", "nuscenes"])
def test_scenario_online_env(data_directory, render=False):
path = pathlib.Path(AssetLoader.file_path(AssetLoader.asset_path, data_directory, unix_style=False))
summary, scenario_ids, mapping = read_dataset_summary(path)
try:
env = ScenarioOnlineEnv(config=dict(
use_render=render,
agent_policy=ReplayEgoCarPolicy,
))
for file_name, file_path in mapping.items():
full_path = path / file_path / file_name
assert full_path.exists(), f"{full_path} does not exist"
scenario_description = read_scenario_data(full_path)
print("Running scenario: ", scenario_description["id"])
env.set_scenario(scenario_description)
env.reset()
for i in range(1000):
o, r, tm, tc, info = env.step([1.0, 0.])
assert env.observation_space.contains(o)
if tm or tc:
assert info["arrive_dest"], "Can not arrive dest"
print("{} track_length: ".format(env.engine.global_seed), info["track_length"])
# assert info["arrive_dest"], "Can not arrive dest"
break

if i == 999:
raise ValueError("Can not arrive dest")
assert env.agent.panda_color == sns.color_palette("colorblind")[2]
finally:
env.close()


if __name__ == "__main__":
test_scenario_online_env("nuscenes", render=True)

0 comments on commit 5bf8ea8

Please sign in to comment.