-
Notifications
You must be signed in to change notification settings - Fork 110
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce the ScenarioOnlineEnv (#779)
* Introduce ScenarioOnlineDataManager and ScenarioOnlineEnv * Prepare the ScenarioOnlineEnv * revert the bug causing tire-scale code * Hot fix the bug * Allow to config "set_static", default to False to fix the bug of flickerring visualization. * add a warning * Introduce an example script * Fix the map bug: in metadrive, all SD should be centralized.
- Loading branch information
1 parent
6b19a02
commit 5bf8ea8
Showing
4 changed files
with
172 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
""" | ||
This script demonstrates how to run the scenario online env. The scenario online env is a special environment that | ||
allows users to pass in scenario descriptions online. This is useful for using scenarios generated by external | ||
online algorithms. | ||
In this script, we load the Waymo dataset and run the scenarios in the dataset. We use the | ||
ReplayEgoCarPolicy as the agent policy. | ||
""" | ||
import pathlib | ||
|
||
import seaborn as sns | ||
|
||
from metadrive.engine.asset_loader import AssetLoader | ||
from metadrive.envs.scenario_env import ScenarioOnlineEnv | ||
from metadrive.policy.replay_policy import ReplayEgoCarPolicy | ||
from metadrive.scenario.utils import read_dataset_summary, read_scenario_data | ||
|
||
if __name__ == "__main__": | ||
data_directory = "waymo" | ||
render = True | ||
|
||
path = pathlib.Path(AssetLoader.file_path(AssetLoader.asset_path, data_directory, unix_style=False)) | ||
summary, scenario_ids, mapping = read_dataset_summary(path) | ||
try: | ||
env = ScenarioOnlineEnv(config=dict( | ||
use_render=render, | ||
agent_policy=ReplayEgoCarPolicy, | ||
)) | ||
for file_name, file_path in mapping.items(): | ||
full_path = path / file_path / file_name | ||
assert full_path.exists(), f"{full_path} does not exist" | ||
scenario_description = read_scenario_data(full_path) | ||
print("Running scenario: ", scenario_description["id"]) | ||
env.set_scenario(scenario_description) | ||
env.reset() | ||
for i in range(1000): | ||
o, r, tm, tc, info = env.step([1.0, 0.]) | ||
assert env.observation_space.contains(o) | ||
if tm or tc: | ||
break | ||
|
||
if i == 999: | ||
raise ValueError("Can not arrive dest") | ||
assert env.agent.panda_color == sns.color_palette("colorblind")[2] | ||
finally: | ||
env.close() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import pytest | ||
import seaborn as sns | ||
import numpy as np | ||
|
||
from metadrive.engine.asset_loader import AssetLoader | ||
from metadrive.envs.scenario_env import ScenarioOnlineEnv | ||
from metadrive.policy.idm_policy import TrajectoryIDMPolicy | ||
from metadrive.policy.replay_policy import ReplayEgoCarPolicy | ||
from metadrive.scenario.utils import read_dataset_summary, read_scenario_data | ||
import pickle | ||
import pathlib | ||
|
||
from metadrive.policy.replay_policy import ReplayEgoCarPolicy | ||
|
||
|
||
@pytest.mark.parametrize("data_directory", ["waymo", "nuscenes"]) | ||
def test_scenario_online_env(data_directory, render=False): | ||
path = pathlib.Path(AssetLoader.file_path(AssetLoader.asset_path, data_directory, unix_style=False)) | ||
summary, scenario_ids, mapping = read_dataset_summary(path) | ||
try: | ||
env = ScenarioOnlineEnv(config=dict( | ||
use_render=render, | ||
agent_policy=ReplayEgoCarPolicy, | ||
)) | ||
for file_name, file_path in mapping.items(): | ||
full_path = path / file_path / file_name | ||
assert full_path.exists(), f"{full_path} does not exist" | ||
scenario_description = read_scenario_data(full_path) | ||
print("Running scenario: ", scenario_description["id"]) | ||
env.set_scenario(scenario_description) | ||
env.reset() | ||
for i in range(1000): | ||
o, r, tm, tc, info = env.step([1.0, 0.]) | ||
assert env.observation_space.contains(o) | ||
if tm or tc: | ||
assert info["arrive_dest"], "Can not arrive dest" | ||
print("{} track_length: ".format(env.engine.global_seed), info["track_length"]) | ||
# assert info["arrive_dest"], "Can not arrive dest" | ||
break | ||
|
||
if i == 999: | ||
raise ValueError("Can not arrive dest") | ||
assert env.agent.panda_color == sns.color_palette("colorblind")[2] | ||
finally: | ||
env.close() | ||
|
||
|
||
if __name__ == "__main__": | ||
test_scenario_online_env("nuscenes", render=True) |