Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce the ScenarioOnlineEnv #779

Merged
merged 9 commits into from
Dec 6, 2024
31 changes: 2 additions & 29 deletions metadrive/component/vehicle/base_vehicle.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,6 @@ def __init__(
self.add_body(vehicle_chassis.getChassis())
self.system = vehicle_chassis
self.chassis = self.origin

if self.config["scale"] is not None:
w, l, h = self.config["scale"]
self.FRONT_WHEELBASE *= l
self.REAR_WHEELBASE *= l
self.LATERAL_TIRE_TO_CENTER *= w
self.TIRE_RADIUS *= h
self.CHASSIS_TO_WHEEL_AXIS *= h

self.wheels = self._create_wheel()

# light experimental!
Expand Down Expand Up @@ -635,7 +626,7 @@ def _create_vehicle_chassis(self):

def _add_visualization(self):
if self.render:
path, scale, offset, HPR = self.path
[path, scale, offset, HPR] = self.path
should_update = (path not in BaseVehicle.model_collection) or (self.config["scale"] is not None)

if should_update:
Expand All @@ -659,10 +650,8 @@ def _add_visualization(self):
car_model.setHpr(*HPR)
car_model.setPos(offset[0], offset[1], offset[2] + extra_offset_z)
BaseVehicle.model_collection[path] = car_model

else:
car_model = BaseVehicle.model_collection[path]

car_model.instanceTo(self.origin)
if self.config["random_color"]:
material = Material()
Expand Down Expand Up @@ -704,23 +693,7 @@ def _add_wheel(self, pos: Vec3, radius: float, front: bool, left):
wheel_model = self.loader.loadModel(model_path)
wheel_model.setTwoSided(self.TIRE_TWO_SIDED)
wheel_model.reparentTo(wheel_np)
tire_scale = 1 * self.TIRE_MODEL_CORRECT if left else -1 * self.TIRE_MODEL_CORRECT

if self.config['scale'] is not None:
tire_scale = (
self.config['scale'][0] * tire_scale, self.config['scale'][1] * tire_scale,
self.config['scale'][2] * tire_scale
)

# A quick workaround here.
# The model position is set to height/2 in ScenarioMapManager.
# Now we set this offset to -height/2, so that the model will be placed on the ground.
# For the wheel, the bottom of it is not z=0, so we add two more terms to correct it.
extra_offset = -self.config["height"] / 2 + self.TIRE_RADIUS / self.config['scale'][
2] + self.CHASSIS_TO_WHEEL_AXIS / self.config['scale'][2]
wheel_model.setPos(0, 0, extra_offset)

wheel_model.set_scale(tire_scale)
wheel_model.set_scale(1 * self.TIRE_MODEL_CORRECT if left else -1 * self.TIRE_MODEL_CORRECT)
wheel = self.system.createWheel()
wheel.setNode(wheel_np.node())
wheel.setChassisConnectionPointCs(pos)
Expand Down
31 changes: 28 additions & 3 deletions metadrive/envs/scenario_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from metadrive.envs.base_env import BaseEnv
from metadrive.manager.scenario_agent_manager import ScenarioAgentManager
from metadrive.manager.scenario_curriculum_manager import ScenarioCurriculumManager
from metadrive.manager.scenario_data_manager import ScenarioDataManager
from metadrive.manager.scenario_data_manager import ScenarioDataManager, ScenarioOnlineDataManager
from metadrive.manager.scenario_light_manager import ScenarioLightManager
from metadrive.manager.scenario_map_manager import ScenarioMapManager
from metadrive.manager.scenario_traffic_manager import ScenarioTrafficManager
Expand Down Expand Up @@ -59,6 +59,12 @@
lane_line_detector=dict(num_lasers=0, distance=50),
side_detector=dict(num_lasers=12, distance=50),
),
# If set_static=True, then the agent will not "fall from the sky". This will be helpful if you want to
# capture per-frame data for the agent (for example for collecting static sensor data).
# However, the physics engine will not update the position of the agent. So in the visualization, the image will be
# very chunky as the agent will not suddenly move to the next position for each step.
# Set to False for better visualization.
set_static=False,

# ===== Reward Scheme =====
# See: https://github.com/metadriverse/metadrive/issues/283
Expand Down Expand Up @@ -393,6 +399,24 @@ def _reset_global_seed(self, force_seed=None):
self.seed(current_seed)


class ScenarioOnlineEnv(ScenarioEnv):
"""
This environment allow the user to pass in scenario data directly.
"""
def __init__(self, config=None):
super(ScenarioOnlineEnv, self).__init__(config)
self.lazy_init()

def setup_engine(self):
"""Overwrite the data_manager by ScenarioOnlineDataManager"""
super().setup_engine()
self.engine.update_manager("data_manager", ScenarioOnlineDataManager())

def set_scenario(self, scenario_data):
"""Please call this function before env.reset()"""
self.engine.data_manager.set_scenario(scenario_data)


if __name__ == "__main__":
env = ScenarioEnv(
{
Expand All @@ -410,7 +434,8 @@ def _reset_global_seed(self, force_seed=None):
# "no_traffic":True,
# "start_scenario_index": 192,
# "start_scenario_index": 1000,
"num_scenarios": 30,
"num_scenarios": 3,
"set_static": True,
# "force_reuse_object_name": True,
# "data_directory": "/home/shady/Downloads/test_processed",
"horizon": 1000,
Expand All @@ -424,7 +449,7 @@ def _reset_global_seed(self, force_seed=None):
lane_line_detector=dict(num_lasers=12, distance=50),
side_detector=dict(num_lasers=160, distance=50)
),
"data_directory": AssetLoader.file_path("nuplan", unix_style=False),
"data_directory": AssetLoader.file_path("nuscenes", unix_style=False),
}
)
success = []
Expand Down
46 changes: 46 additions & 0 deletions metadrive/examples/run_scenario_online_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""
This script demonstrates how to run the scenario online env. The scenario online env is a special environment that
allows users to pass in scenario descriptions online. This is useful for using scenarios generated by external
online algorithms.

In this script, we load the Waymo dataset and run the scenarios in the dataset. We use the
ReplayEgoCarPolicy as the agent policy.
"""
import pathlib

import seaborn as sns

from metadrive.engine.asset_loader import AssetLoader
from metadrive.envs.scenario_env import ScenarioOnlineEnv
from metadrive.policy.replay_policy import ReplayEgoCarPolicy
from metadrive.scenario.utils import read_dataset_summary, read_scenario_data

if __name__ == "__main__":
data_directory = "waymo"
render = True

path = pathlib.Path(AssetLoader.file_path(AssetLoader.asset_path, data_directory, unix_style=False))
summary, scenario_ids, mapping = read_dataset_summary(path)
try:
env = ScenarioOnlineEnv(config=dict(
use_render=render,
agent_policy=ReplayEgoCarPolicy,
))
for file_name, file_path in mapping.items():
full_path = path / file_path / file_name
assert full_path.exists(), f"{full_path} does not exist"
scenario_description = read_scenario_data(full_path)
print("Running scenario: ", scenario_description["id"])
env.set_scenario(scenario_description)
env.reset()
for i in range(1000):
o, r, tm, tc, info = env.step([1.0, 0.])
assert env.observation_space.contains(o)
if tm or tc:
break

if i == 999:
raise ValueError("Can not arrive dest")
assert env.agent.panda_color == sns.color_palette("colorblind")[2]
finally:
env.close()
58 changes: 58 additions & 0 deletions metadrive/manager/scenario_data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,61 @@ def destroy(self):
self.summary_lookup.clear()
self.mapping.clear()
self.summary_dict, self.summary_lookup, self.mapping = None, None, None


class ScenarioOnlineDataManager(BaseManager):
"""
Compared to ScenarioDataManager, this manager allow user to pass in Scenario Description online.
It will not read data from disk, but receive data from user.
"""
PRIORITY = -10
_scenario = None

@property
def current_scenario_summary(self):
return self.current_scenario[SD.METADATA]

def set_scenario(self, scenario_description):
SD.sanity_check(scenario_description)
scenario_description = SD.centralize_to_ego_car_initial_position(scenario_description)
self._scenario = scenario_description

def get_scenario(self, seed=None, should_copy=False):
assert self._scenario is not None, "Please set scenario first via env.set_scenario(scenario_description)!"
if should_copy:
return copy.deepcopy(self._scenario)
return self._scenario

def get_metadata(self):
raise ValueError()
state = super(ScenarioDataManager, self).get_metadata()
raw_data = self.current_scenario
state["raw_data"] = raw_data
return state

@property
def current_scenario_length(self):
return self.current_scenario[SD.LENGTH]

@property
def current_scenario(self):
return self._scenario

@property
def current_scenario_difficulty(self):
return 0

@property
def current_scenario_id(self):
return self.current_scenario_summary["scenario_id"]

@property
def data_coverage(self):
return None

def destroy(self):
"""
Clear memory
"""
super(ScenarioOnlineDataManager, self).destroy()
self._scenario = None
9 changes: 5 additions & 4 deletions metadrive/manager/scenario_traffic_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,11 @@ def spawn_vehicle(self, v_id, track):
v_cfg["width"] = state["width"]
v_cfg["length"] = state["length"]
v_cfg["height"] = state["height"]
v_cfg["scale"] = (
v_cfg["width"] / vehicle_class.DEFAULT_WIDTH, v_cfg["length"] / vehicle_class.DEFAULT_LENGTH,
v_cfg["height"] / vehicle_class.DEFAULT_HEIGHT
)
if use_bounding_box:
v_cfg["scale"] = (
v_cfg["width"] / vehicle_class.DEFAULT_WIDTH, v_cfg["length"] / vehicle_class.DEFAULT_LENGTH,
v_cfg["height"] / vehicle_class.DEFAULT_HEIGHT
)

if self.engine.global_config["top_down_show_real_size"]:
v_cfg["top_down_length"] = track["state"]["length"][self.episode_step]
Expand Down
9 changes: 8 additions & 1 deletion metadrive/policy/replay_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,14 @@ def act(self, *args, **kwargs):
self.control_object.set_velocity(info["velocity"], in_local_frame=self._velocity_local_frame)
self.control_object.set_heading_theta(info["heading"])
self.control_object.set_angular_velocity(info["angular_velocity"])
self.control_object.set_static(True)

# If set_static, then the agent will not "fall from the sky".
# However, the physics engine will not update the position of the agent.
# So in the visualization, the image will be very chunky as the agent will not suddenly move to the next
# position for each step.

if self.engine.global_config.get("set_static", False):
self.control_object.set_static(True)

return None # Return None action so the base vehicle will not overwrite the steering & throttle

Expand Down
49 changes: 49 additions & 0 deletions metadrive/tests/test_env/test_scenario_online_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pytest
import seaborn as sns
import numpy as np

from metadrive.engine.asset_loader import AssetLoader
from metadrive.envs.scenario_env import ScenarioOnlineEnv
from metadrive.policy.idm_policy import TrajectoryIDMPolicy
from metadrive.policy.replay_policy import ReplayEgoCarPolicy
from metadrive.scenario.utils import read_dataset_summary, read_scenario_data
import pickle
import pathlib

from metadrive.policy.replay_policy import ReplayEgoCarPolicy


@pytest.mark.parametrize("data_directory", ["waymo", "nuscenes"])
def test_scenario_online_env(data_directory, render=False):
path = pathlib.Path(AssetLoader.file_path(AssetLoader.asset_path, data_directory, unix_style=False))
summary, scenario_ids, mapping = read_dataset_summary(path)
try:
env = ScenarioOnlineEnv(config=dict(
use_render=render,
agent_policy=ReplayEgoCarPolicy,
))
for file_name, file_path in mapping.items():
full_path = path / file_path / file_name
assert full_path.exists(), f"{full_path} does not exist"
scenario_description = read_scenario_data(full_path)
print("Running scenario: ", scenario_description["id"])
env.set_scenario(scenario_description)
env.reset()
for i in range(1000):
o, r, tm, tc, info = env.step([1.0, 0.])
assert env.observation_space.contains(o)
if tm or tc:
assert info["arrive_dest"], "Can not arrive dest"
print("{} track_length: ".format(env.engine.global_seed), info["track_length"])
# assert info["arrive_dest"], "Can not arrive dest"
break

if i == 999:
raise ValueError("Can not arrive dest")
assert env.agent.panda_color == sns.color_palette("colorblind")[2]
finally:
env.close()


if __name__ == "__main__":
test_scenario_online_env("nuscenes", render=True)
Loading