From 9b3fde1b3ebd9382d30de7ac4c60ce868f681a53 Mon Sep 17 00:00:00 2001 From: Zhenghao Peng Date: Tue, 20 Aug 2024 15:44:24 -0700 Subject: [PATCH] Introducing a single agent, multi-goal, four-ways intersection environment (#699) * better way to handle U turn in Intersection * INIT * Revert "better way to handle U turn in Intersection" This reverts commit d9bf4bb4233a3c4aac817a1f6d112ff1d0b9ba91. * add doc * fix bug * add docs * If navigation_module is set to None, remove navigation in StateObservation. (I am not 100% sure if this is OK. Might cause error.) * WIP: Implementing a navigation/goal manager * Implemented goal / navigation managers * adding a U-turn intersection block * add assert in node_road_network * fix potential bug * Now the environment is running correctly! (but reset is broken) * introduce observation * format * add more docs * introduce a config to disable navigation arrows * minor * use varying dynamics agent in multi-goal env * support use a list of types to control PG map; allow set map=None to disable shortcut config map. * allow change radius of the intersection; set accident_prob=1.0 * 1) introduce goal-agnoistic info["obs/ego/*"], 2) rename keys in info dict * optimize the reward structure, now: ===== timestep 220 ===== route completion: route_completion/goals/default: 0.92 route_completion/goals/go_straight: 0.50 route_completion/goals/left_turn: 0.46 route_completion/goals/right_turn: 0.92 route_completion/goals/u_turn: 0.53 reward: reward/default_reward: 1.27 reward/goal_agnostic_reward: 0.05 reward/goals/default: 1.27 reward/goals/go_straight: 0.14 reward/goals/left_turn: 0.06 reward/goals/right_turn: 1.27 reward/goals/u_turn: 0.01 ======================= * to follow setting, use lane_num=1 * format, ready to launch SB3 td3 * Remove varying dynamics * lane_num=1 * allow do more visualization * add default arrive_dest * add some comments * now we return full observation for different goals in info["obs/goals/xxx"] * [DANGER] allow to generate sidewalk for "negative road". Not sure the affect of this commit in other cases. Might need further check. * Add SIDEWALK to the side detector & the lane line detector. * use 240line for sidedetector, remove vehicle/lane detector * fix a bug * Fix a severe bug that messes up observation * introduce a penalty for wrong way * When draw the line to next checkpoint, also draw the line from next ckpt to next next ckpt. * Add crash_sidewalk_penalty for MetaDrive env, default crash_sidewalk_penalty=0 * Set on_continuous_line_done=False for multigoal env * Add config "out_of_road_done" for MDEnv * Remove U turn * Set out_of_road_done=False * Add penalty for out_of_route (this might be helpful in multigoal setting) * Change reward scheme * Change radius to 12 * remove goal_agnostic_reward * add GOAL_DEPENDENT_STATE * up * change obs * Add some randomness in map * enable U turn * WIP: Now support conventional RL env * fix bug * Fix a bug and use Customize Observation * Better handle lidars' configs * Remove those hyper diff from MetaDriveEnv * fix * Fix * format * remove a file * Add an example notebook for multigoalintersection * Setup FFpmeg in CI to support video gen in docs * minor * minor * Add docs * Fix test * format * try fix ffmpeg * Fix test --- .github/workflows/main.yml | 21 + documentation/source/index.rst | 1 + .../source/multigoal_intersection.ipynb | 272 ++++++++ metadrive/component/algorithm/BIG.py | 21 +- .../component/algorithm/blocks_prob_dist.py | 3 +- metadrive/component/map/pg_map.py | 2 +- .../navigation_module/base_navigation.py | 19 +- .../edge_network_navigation.py | 18 +- .../node_network_navigation.py | 29 +- metadrive/component/pgblock/intersection.py | 18 +- metadrive/component/pgblock/pg_block.py | 11 +- .../component/pgblock/std_intersection.py | 9 +- .../road_network/node_road_network.py | 1 + .../component/sensors/distance_detector.py | 5 +- metadrive/component/sensors/lidar.py | 1 - metadrive/component/vehicle/base_vehicle.py | 49 +- metadrive/engine/interface.py | 2 + metadrive/envs/base_env.py | 2 + metadrive/envs/metadrive_env.py | 7 +- metadrive/envs/multigoal_intersection.py | 616 ++++++++++++++++++ .../train_generalization_experiment.py | 7 +- metadrive/manager/traffic_manager.py | 18 +- metadrive/obs/state_obs.py | 13 +- .../test_component/test_lane_line_detector.py | 167 +---- metadrive/utils/registry.py | 4 +- setup.py | 1 + 26 files changed, 1110 insertions(+), 207 deletions(-) create mode 100644 documentation/source/multigoal_intersection.ipynb create mode 100644 metadrive/envs/multigoal_intersection.py diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 09e0b23ca..491b0b01a 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -280,10 +280,31 @@ jobs: run: | sudo apt-get -y install xvfb sudo /usr/bin/Xvfb :0 -screen 0 1280x1024x24 & + - name: Setup FFmpeg + uses: FedericoCarboni/setup-ffmpeg@v3 + id: setup-ffmpeg + with: + # A specific version to download, may also be "release" or a specific version + # like "6.1.0". At the moment semver specifiers (i.e. >=6.1.0) are supported + # only on Windows, on other platforms they are allowed but version is matched + # exactly regardless. + ffmpeg-version: release + # Target architecture of the ffmpeg executable to install. Defaults to the + # system architecture. Only x64 and arm64 are supported (arm64 only on Linux). + architecture: '' + # Linking type of the binaries. Use "shared" to download shared binaries and + # "static" for statically linked ones. Shared builds are currently only available + # for windows releases. Defaults to "static" + linking-type: static + # As of version 3 of this action, builds are no longer downloaded from GitHub + # except on Windows: https://github.com/GyanD/codexffmpeg/releases. + github-token: ${{ github.server_url == 'https://github.com' && github.token || '' }} - name: Blackbox tests run: | pip install cython pip install numpy + pip install mediapy + conda install ffmpeg pip install -e . pip install -e .[gym] python -m metadrive.pull_asset diff --git a/documentation/source/index.rst b/documentation/source/index.rst index c5c618c3c..69e48e956 100644 --- a/documentation/source/index.rst +++ b/documentation/source/index.rst @@ -47,6 +47,7 @@ Please feel free to contact us if you have any suggestions or ideas! action.ipynb reward_cost_done.ipynb training.ipynb + multigoal_intersection.ipynb .. toctree:: :hidden: diff --git a/documentation/source/multigoal_intersection.ipynb b/documentation/source/multigoal_intersection.ipynb new file mode 100644 index 000000000..e0224cb8b --- /dev/null +++ b/documentation/source/multigoal_intersection.ipynb @@ -0,0 +1,272 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "2832faf1-1bd3-4a95-8b0d-b3289e74d4d0", + "metadata": {}, + "source": [ + "# Demonstration on MultigoalIntersection\n", + "\n", + "In this notebook, we demonstrate how to setup a multigoal intersection environment where you can access relevant stats (e.g. route completion, reward, success rate) for all four possible goals (right turn, left turn, move forward, U turn) simultaneously.\n", + "\n", + "We demonstrate how to build the environment, in which we have successfully trained a SAC expert that achieves 99% success rate, and how to access those stats in the info dict returned each step.\n", + "\n", + "*Note: We pretrain the SAC expert with `use_multigoal_intersection=False` and then finetune it with `use_multigoal_intersection=True`.*" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "b9733eac-9d07-47cf-bda7-4dbb8d5f2412", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from metadrive.envs.gym_wrapper import create_gym_wrapper\n", + "from metadrive.envs.multigoal_intersection import MultiGoalIntersectionEnv\n", + "import mediapy as media\n", + "\n", + "render = False\n", + "num_scenarios = 1000\n", + "start_seed = 100\n", + "goal_probabilities = {\n", + " \"right_turn\": 0.25,\n", + " \"left_turn\": 0.25,\n", + " \"go_straight\": 0.25,\n", + " \"u_turn\": 0.25\n", + "}\n", + "\n", + "\n", + "class MultiGoalWrapped(MultiGoalIntersectionEnv):\n", + " current_goal = None\n", + "\n", + " def step(self, actions):\n", + " o, r, tm, tc, i = super().step(actions)\n", + "\n", + " o = i['obs/goals/{}'.format(self.current_goal)]\n", + " r = i['reward/goals/{}'.format(self.current_goal)]\n", + " i['route_completion'] = i['route_completion/goals/{}'.format(self.current_goal)]\n", + " i['arrive_dest'] = i['arrive_dest/goals/{}'.format(self.current_goal)]\n", + " i['reward/goals/default'] = i['reward/goals/{}'.format(self.current_goal)]\n", + " i['route_completion/goals/default'] = i['route_completion/goals/{}'.format(self.current_goal)]\n", + " i['arrive_dest/goals/default'] = i['arrive_dest/goals/{}'.format(self.current_goal)]\n", + " i[\"current_goal\"] = self.current_goal\n", + " return o, r, tm, tc, i\n", + "\n", + " def reset(self, *args, **kwargs):\n", + " o, i = super().reset(*args, **kwargs)\n", + "\n", + " # Sample a goal from the goal set\n", + " if self.config[\"use_multigoal_intersection\"]:\n", + " p = goal_probabilities\n", + " self.current_goal = np.random.choice(list(p.keys()), p=list(p.values()))\n", + "\n", + " else:\n", + " self.current_goal = \"default\"\n", + "\n", + " o = i['obs/goals/{}'.format(self.current_goal)]\n", + " i['route_completion'] = i['route_completion/goals/{}'.format(self.current_goal)]\n", + " i['arrive_dest'] = i['arrive_dest/goals/{}'.format(self.current_goal)]\n", + " i['reward/goals/default'] = i['reward/goals/{}'.format(self.current_goal)]\n", + " i['route_completion/goals/default'] = i['route_completion/goals/{}'.format(self.current_goal)]\n", + " i['arrive_dest/goals/default'] = i['arrive_dest/goals/{}'.format(self.current_goal)]\n", + " i[\"current_goal\"] = self.current_goal\n", + "\n", + " return o, i" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "f5b6f059-52f8-46ee-bcfe-dee6f4d2e2e6", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[38;20m[INFO] Environment: MultiGoalWrapped\u001b[0m\n", + "\u001b[38;20m[INFO] MetaDrive version: 0.4.2.3\u001b[0m\n", + "\u001b[38;20m[INFO] Sensors: [lidar: Lidar(), side_detector: SideDetector(), lane_line_detector: LaneLineDetector()]\u001b[0m\n", + "\u001b[38;20m[INFO] Render Mode: none\u001b[0m\n", + "\u001b[38;20m[INFO] Horizon (Max steps per agent): 500\u001b[0m\n" + ] + } + ], + "source": [ + "\n", + "\n", + "env_config = dict(\n", + " use_render=render,\n", + " manual_control=False,\n", + " vehicle_config=dict(show_lidar=False, show_navi_mark=True, show_line_to_navi_mark=True,\n", + " show_line_to_dest=True, show_dest_mark=True),\n", + " horizon=500, # to speed up training\n", + "\n", + " traffic_density=0.06,\n", + " \n", + " use_multigoal_intersection=True, # Set to False if want to use the same observation but with original PG scenarios.\n", + " out_of_route_done=False,\n", + "\n", + " num_scenarios=num_scenarios,\n", + " start_seed=start_seed,\n", + " accident_prob=0.8,\n", + " crash_vehicle_done=False,\n", + " crash_object_done=False,\n", + ")\n", + "\n", + "wrapped = create_gym_wrapper(MultiGoalWrapped)\n", + "\n", + "env = wrapped(env_config)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "ae2abe78-f3e3-40b9-88dd-a958fc932363", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[38;20m[INFO] Assets version: 0.4.2.3\u001b[0m\n", + "\u001b[38;20m[INFO] Known Pipes: glxGraphicsPipe\u001b[0m\n", + "\u001b[38;20m[INFO] Start Scenario Index: 100, Num Scenarios : 1000\u001b[0m\n", + "\u001b[33;20m[WARNING] env.vehicle will be deprecated soon. Use env.agent instead (base_env.py:731)\u001b[0m\n", + "\u001b[38;20m[INFO] Episode ended! Scenario Index: 606 Reason: arrive_dest.\u001b[0m\n" + ] + } + ], + "source": [ + "frames = []\n", + "\n", + "env.reset()\n", + "while True:\n", + " action = [0, 1]\n", + " o, r, d, i = env.step(action)\n", + " frame = env.render(mode=\"topdown\")\n", + " frames.append(frame)\n", + " if d:\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "40ac0392-67e3-4d2d-a9bd-2065831e43ca", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Output at final step:\n", + "\tacceleration: 1.000\n", + "\tarrive_dest: 1.000\n", + "\tarrive_dest/goals/default: 1.000\n", + "\tarrive_dest/goals/go_straight: 1.000\n", + "\tarrive_dest/goals/left_turn: 0.000\n", + "\tarrive_dest/goals/right_turn: 0.000\n", + "\tarrive_dest/goals/u_turn: 0.000\n", + "\tcost: 0.000\n", + "\tcrash: 0.000\n", + "\tcrash_building: 0.000\n", + "\tcrash_human: 0.000\n", + "\tcrash_object: 0.000\n", + "\tcrash_sidewalk: 0.000\n", + "\tcrash_vehicle: 0.000\n", + "\tcurrent_goal: go_straight\n", + "\tenv_seed: 606.000\n", + "\tepisode_energy: 6.986\n", + "\tepisode_length: 88.000\n", + "\tepisode_reward: 35.834\n", + "\tmax_step: 0.000\n", + "\tnavigation_command: right\n", + "\tnavigation_forward: 0.000\n", + "\tnavigation_left: 0.000\n", + "\tnavigation_right: 1.000\n", + "\tout_of_road: 0.000\n", + "\tovertake_vehicle_num: 0.000\n", + "\tpolicy: EnvInputPolicy\n", + "\treward/default_reward: -10.000\n", + "\treward/goals/default: 12.335\n", + "\treward/goals/go_straight: 12.335\n", + "\treward/goals/left_turn: -10.000\n", + "\treward/goals/right_turn: -10.000\n", + "\treward/goals/u_turn: -10.000\n", + "\troute_completion: 0.969\n", + "\troute_completion/goals/default: 0.969\n", + "\troute_completion/goals/go_straight: 0.969\n", + "\troute_completion/goals/left_turn: 0.632\n", + "\troute_completion/goals/right_turn: 0.643\n", + "\troute_completion/goals/u_turn: 0.552\n", + "\tsteering: 0.000\n", + "\tstep_energy: 0.162\n", + "\tvelocity: 22.313\n" + ] + } + ], + "source": [ + "print(\"Output at final step:\")\n", + "\n", + "i = {k: i[k] for k in sorted(i.keys())}\n", + "for k, v in i.items():\n", + " if isinstance(v, str):\n", + " s = v\n", + " elif np.iterable(v):\n", + " continue\n", + " else:\n", + " s = \"{:.3f}\".format(v)\n", + " print(\"\\t{}: {}\".format(k, s))" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "dc986e4e-f81c-4882-88b2-9eb306552fb3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "media.show_video(frames)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/metadrive/component/algorithm/BIG.py b/metadrive/component/algorithm/BIG.py index c322c472d..3384b3631 100644 --- a/metadrive/component/algorithm/BIG.py +++ b/metadrive/component/algorithm/BIG.py @@ -73,9 +73,17 @@ def generate(self, generate_method: str, parameter: Union[str, int]): assert isinstance(parameter, int), "When generating map by assigning block num, the parameter should be int" self.block_num = parameter + 1 elif generate_method == BigGenerateMethod.BLOCK_SEQUENCE: - assert isinstance(parameter, str), "When generating map from block sequence, the parameter should be a str" - self.block_num = len(parameter) + 1 - self._block_sequence = FirstPGBlock.ID + parameter + if isinstance(parameter, list): + self.block_num = len(parameter) + 1 + self._block_sequence = [FirstPGBlock] + parameter + else: + assert isinstance( + parameter, str + ), "When generating map from block sequence, the parameter should be a str. But got {}".format( + type(parameter) + ) + self.block_num = len(parameter) + 1 + self._block_sequence = FirstPGBlock.ID + parameter while True: if self.big_helper_func(): break @@ -104,8 +112,11 @@ def sample_block(self) -> PGBlock: block_type = self.np_random.choice(block_types, p=block_probabilities) block_type = get_metadrive_class(block_type) else: - type_id = self._block_sequence[len(self.blocks)] - block_type = self.block_dist_config.get_block(type_id) + if isinstance(self._block_sequence[0], str): + type_id = self._block_sequence[len(self.blocks)] + block_type = self.block_dist_config.get_block(type_id) + else: + block_type = self._block_sequence[len(self.blocks)] socket = self.np_random.choice(self.blocks[-1].get_socket_indices()) block = block_type( diff --git a/metadrive/component/algorithm/blocks_prob_dist.py b/metadrive/component/algorithm/blocks_prob_dist.py index bbf3e1b3a..86f4e179d 100644 --- a/metadrive/component/algorithm/blocks_prob_dist.py +++ b/metadrive/component/algorithm/blocks_prob_dist.py @@ -37,7 +37,8 @@ class PGBlockDistConfig: "Split": 0.00, "ParkingLot": 0.00, "TollGate": 0.00, - "Bidirection": 0.00 + "Bidirection": 0.00, + "StdInterSectionWithUTurn": 0.00 } @classmethod diff --git a/metadrive/component/map/pg_map.py b/metadrive/component/map/pg_map.py index 98ad32c36..b99940c36 100644 --- a/metadrive/component/map/pg_map.py +++ b/metadrive/component/map/pg_map.py @@ -20,7 +20,7 @@ def parse_map_config(easy_map_config, new_map_config, default_config): assert isinstance(default_config, Config) # Return the user specified config if overwritten - if not default_config["map_config"].is_identical(new_map_config): + if easy_map_config is None or not default_config["map_config"].is_identical(new_map_config): new_map_config = default_config["map_config"].copy(unchangeable=False).update(new_map_config) assert default_config["map"] == easy_map_config return new_map_config diff --git a/metadrive/component/navigation_module/base_navigation.py b/metadrive/component/navigation_module/base_navigation.py index 24c04daa9..9f09643cc 100644 --- a/metadrive/component/navigation_module/base_navigation.py +++ b/metadrive/component/navigation_module/base_navigation.py @@ -68,6 +68,7 @@ def __init__( self.navi_arrow_dir = [0, 0] self._dest_node_path = None self._goal_node_path = None + self._goal_node_path2 = None self._node_path_list = [] @@ -78,15 +79,20 @@ def __init__( # nodepath self._line_to_dest = self.origin.attachNewNode("line") self._goal_node_path = self.origin.attachNewNode("target") + self._goal_node_path2 = self.origin.attachNewNode("target2") self._dest_node_path = self.origin.attachNewNode("dest") self._node_path_list.append(self._line_to_dest) self._node_path_list.append(self._goal_node_path) + self._node_path_list.append(self._goal_node_path2) self._node_path_list.append(self._dest_node_path) if show_navi_mark: navi_point_model = AssetLoader.loader.loadModel(AssetLoader.file_path("models", "box.bam")) navi_point_model.reparentTo(self._goal_node_path) + + navi_point_model2 = AssetLoader.loader.loadModel(AssetLoader.file_path("models", "box.bam")) + navi_point_model2.reparentTo(self._goal_node_path2) if show_dest_mark: dest_point_model = AssetLoader.loader.loadModel(AssetLoader.file_path("models", "box.bam")) dest_point_model.reparentTo(self._dest_node_path) @@ -108,18 +114,20 @@ def __init__( line_seg.setColor(self.navi_mark_color[0], self.navi_mark_color[1], self.navi_mark_color[2], 1.0) line_seg.setThickness(4) self._dynamic_line_np_2 = NodePath(line_seg.create(True)) - self._node_path_list.append(self._dynamic_line_np_2) - self._dynamic_line_np_2.reparentTo(self.origin) self._line_to_navi = line_seg self._goal_node_path.setTransparency(TransparencyAttrib.M_alpha) + self._goal_node_path2.setTransparency(TransparencyAttrib.M_alpha) self._dest_node_path.setTransparency(TransparencyAttrib.M_alpha) self._goal_node_path.setColor( self.navi_mark_color[0], self.navi_mark_color[1], self.navi_mark_color[2], 0.7 ) + self._goal_node_path2.setColor( + self.navi_mark_color[0], self.navi_mark_color[1], self.navi_mark_color[2], 0.5 + ) self._dest_node_path.setColor( self.navi_mark_color[0], self.navi_mark_color[1], self.navi_mark_color[2], 0.7 ) @@ -180,6 +188,7 @@ def destroy(self): pass self._dest_node_path.removeNode() self._goal_node_path.removeNode() + self._goal_node_path2.removeNode() for np in self._node_path_list: np.detachNode() @@ -234,12 +243,16 @@ def _draw_line_to_dest(self, start_position, end_position): self._dynamic_line_np.hide(CamMask.Shadow | CamMask.RgbCam) self._dynamic_line_np.reparentTo(self.origin) - def _draw_line_to_navi(self, start_position, end_position): + def _draw_line_to_navi(self, start_position, end_position, next_checkpoint=None): if not self._show_line_to_navi_mark: return line_seg = self._line_to_navi line_seg.moveTo(panda_vector(start_position, self.LINE_TO_DEST_HEIGHT)) line_seg.drawTo(panda_vector(end_position, self.LINE_TO_DEST_HEIGHT)) + + if next_checkpoint is not None: + line_seg.drawTo(panda_vector(next_checkpoint, self.LINE_TO_DEST_HEIGHT)) + self._dynamic_line_np_2.removeNode() self._dynamic_line_np_2 = NodePath(line_seg.create(False)) diff --git a/metadrive/component/navigation_module/edge_network_navigation.py b/metadrive/component/navigation_module/edge_network_navigation.py index 9082bd84e..80bd992a8 100644 --- a/metadrive/component/navigation_module/edge_network_navigation.py +++ b/metadrive/component/navigation_module/edge_network_navigation.py @@ -100,27 +100,37 @@ def update_localization(self, ego_vehicle): self._navi_info.fill(0.0) half = self.CHECK_POINT_INFO_DIM - self._navi_info[:half], lanes_heading1, checkpoint = self._get_info_for_checkpoint( + self._navi_info[:half], lanes_heading1, next_checkpoint = self._get_info_for_checkpoint( lanes_id=0, ref_lane=self.map.road_network.get_lane(self.current_checkpoint_lane_index), ego_vehicle=ego_vehicle ) - self._navi_info[half:], lanes_heading2, _ = self._get_info_for_checkpoint( + self._navi_info[half:], lanes_heading2, next_next_checkpoint = self._get_info_for_checkpoint( lanes_id=1, ref_lane=self.map.road_network.get_lane(self.next_checkpoint_lane_index), ego_vehicle=ego_vehicle ) if self._show_navi_info: # Whether to visualize little boxes in the scene denoting the checkpoints - pos_of_goal = checkpoint + pos_of_goal = next_checkpoint self._goal_node_path.setPos(panda_vector(pos_of_goal[0], pos_of_goal[1], self.MARK_HEIGHT)) self._goal_node_path.setH(self._goal_node_path.getH() + 3) + + pos_of_goal = next_next_checkpoint + self._goal_node_path2.setPos(panda_vector(pos_of_goal[0], pos_of_goal[1], self.MARK_HEIGHT)) + self._goal_node_path2.setH(self._goal_node_path2.getH() + 3) + self.navi_arrow_dir = [lanes_heading1, lanes_heading2] dest_pos = self._dest_node_path.getPos() self._draw_line_to_dest(start_position=ego_vehicle.position, end_position=(dest_pos[0], dest_pos[1])) navi_pos = self._goal_node_path.getPos() - self._draw_line_to_navi(start_position=ego_vehicle.position, end_position=(navi_pos[0], navi_pos[1])) + next_navi_pos = self._goal_node_path2.getPos() + self._draw_line_to_navi( + start_position=ego_vehicle.position, + end_position=(navi_pos[0], navi_pos[1]), + next_checkpoint=(next_navi_pos[0], next_navi_pos[1]) + ) def _update_target_checkpoints(self, ego_lane_index) -> bool: """ diff --git a/metadrive/component/navigation_module/node_network_navigation.py b/metadrive/component/navigation_module/node_network_navigation.py index 596d08358..9b2fa81f4 100644 --- a/metadrive/component/navigation_module/node_network_navigation.py +++ b/metadrive/component/navigation_module/node_network_navigation.py @@ -40,7 +40,7 @@ def __init__( self.current_road = None self.next_road = None - def reset(self, vehicle): + def reset(self, vehicle, dest=None, random_seed=None): possible_lanes = ray_localization(vehicle.heading, vehicle.spawn_place, self.engine, use_heading_filter=False) possible_lane_indexes = [lane_index for lane, lane_index, dist in possible_lanes] @@ -56,11 +56,12 @@ def reset(self, vehicle): assert len(possible_lanes) > 0 lane, new_l_index = possible_lanes[0][:-1] - dest = vehicle.config["destination"] + if dest is None: + dest = vehicle.config["destination"] current_lane = lane destination = dest if dest is not None else None - random_seed = self.engine.global_random_seed + random_seed = self.engine.global_random_seed if random_seed is None else random_seed assert current_lane is not None, "spawn place is not on road!" super(NodeNetworkNavigation, self).reset(current_lane) assert self.map.road_network_type == NodeRoadNetwork, "This Navigation module only support NodeRoadNetwork type" @@ -188,12 +189,12 @@ def update_localization(self, ego_vehicle): self._navi_info.fill(0.0) half = self.CHECK_POINT_INFO_DIM # Put the next checkpoint's information into the first half of the navi_info - self._navi_info[:half], lanes_heading1, checkpoint = self._get_info_for_checkpoint( + self._navi_info[:half], lanes_heading1, next_checkpoint = self._get_info_for_checkpoint( lanes_id=0, ref_lane=self.current_ref_lanes[0], ego_vehicle=ego_vehicle ) # Put the next of the next checkpoint's information into the first half of the navi_info - self._navi_info[half:], lanes_heading2, _ = self._get_info_for_checkpoint( + self._navi_info[half:], lanes_heading2, next_next_checkpoint = self._get_info_for_checkpoint( lanes_id=1, ref_lane=self.next_ref_lanes[0] if self.next_ref_lanes is not None else self.current_ref_lanes[0], ego_vehicle=ego_vehicle @@ -202,13 +203,23 @@ def update_localization(self, ego_vehicle): self.navi_arrow_dir = [lanes_heading1, lanes_heading2] if self._show_navi_info: # Whether to visualize little boxes in the scene denoting the checkpoints - pos_of_goal = checkpoint + pos_of_goal = next_checkpoint self._goal_node_path.setPos(panda_vector(pos_of_goal[0], pos_of_goal[1], self.MARK_HEIGHT)) self._goal_node_path.setH(self._goal_node_path.getH() + 3) + + pos_of_goal = next_next_checkpoint + self._goal_node_path2.setPos(panda_vector(pos_of_goal[0], pos_of_goal[1], self.MARK_HEIGHT)) + self._goal_node_path2.setH(self._goal_node_path2.getH() + 3) + dest_pos = self._dest_node_path.getPos() self._draw_line_to_dest(start_position=ego_vehicle.position, end_position=(dest_pos[0], dest_pos[1])) navi_pos = self._goal_node_path.getPos() - self._draw_line_to_navi(start_position=ego_vehicle.position, end_position=(navi_pos[0], navi_pos[1])) + next_navi_pos = self._goal_node_path2.getPos() + self._draw_line_to_navi( + start_position=ego_vehicle.position, + end_position=(navi_pos[0], navi_pos[1]), + next_checkpoint=(next_navi_pos[0], next_navi_pos[1]) + ) def _update_target_checkpoints(self, ego_lane_index, ego_lane_longitude) -> bool: """ @@ -250,7 +261,9 @@ def get_current_lateral_range(self, current_position, engine) -> float: def _get_current_lane(self, ego_vehicle): """ - Called in update_localization to find current lane information + Called in update_localization to find current lane information. If the vehicle is in the current reference lane, + meaning it is not yet moving to the next road segment, then return the current reference lane. Otherwise, return + the next reference lane. If the vehicle is not in any of the reference lanes, then return the closest lane. """ possible_lanes, on_lane = ray_localization( ego_vehicle.heading, ego_vehicle.position, ego_vehicle.engine, return_on_lane=True diff --git a/metadrive/component/pgblock/intersection.py b/metadrive/component/pgblock/intersection.py index 8d244d12b..5e91407eb 100644 --- a/metadrive/component/pgblock/intersection.py +++ b/metadrive/component/pgblock/intersection.py @@ -110,7 +110,6 @@ def _create_part(self, attach_lanes, attach_road: Road, radius: float, intersect # u-turn if self._enable_u_turn_flag: - adverse_road = -attach_road self._create_u_turn(attach_road, part_idx) # go forward part @@ -221,6 +220,17 @@ def _create_left_turn(self, radius, lane_num, attach_left_lane, attach_road, int ) def _create_u_turn(self, attach_road, part_idx): + """ + Create a U turn. + + Args: + attach_road: the road where the U turn starts. + part_idx: in [0, 1, 2, 3]. When part_idx!=0, we grab the lanes from road network. Otherwise we use the + initial lanes (positive_lanes). + + Returns: + None. + """ # set to CONTINUOUS to debug line_type = PGLineType.NONE lanes = attach_road.get_lanes(self.block_network) if part_idx != 0 else self.positive_lanes @@ -253,3 +263,9 @@ def get_intermediate_spawn_lanes(self): """Override this function for intersection so that we won't spawn vehicles in the center of intersection.""" respawn_lanes = self.get_respawn_lanes() return respawn_lanes + + +class InterSectionWithUTurn(InterSection): + ID = "U" + _enable_u_turn_flag = True + SOCKET_NUM = 4 diff --git a/metadrive/component/pgblock/pg_block.py b/metadrive/component/pgblock/pg_block.py index c717d7073..c2f1af836 100644 --- a/metadrive/component/pgblock/pg_block.py +++ b/metadrive/component/pgblock/pg_block.py @@ -252,10 +252,15 @@ def create_in_world(self): for _id, lane in enumerate(lanes): self._construct_lane(lane, (_from, _to, _id)) + + # choose_side is a two-elemental list, the first element is for left side, + # the second element is for right side. If False, then the left/right side line (broken line or + # continuous line) will not be constructed. + choose_side = [True, True] if _id == len(lanes) - 1 else [True, False] - if Road(_from, _to).is_negative_road() and _id == 0: - # draw center line with positive road - choose_side = [False, False] + # if Road(_from, _to).is_negative_road() and _id == 0: + # # draw center line with positive road + # choose_side = [False, False] self._construct_lane_line_in_block(lane, choose_side) self._construct_sidewalk() self._construct_crosswalk() diff --git a/metadrive/component/pgblock/std_intersection.py b/metadrive/component/pgblock/std_intersection.py index 2e87792fb..4aac8171d 100644 --- a/metadrive/component/pgblock/std_intersection.py +++ b/metadrive/component/pgblock/std_intersection.py @@ -1,5 +1,5 @@ -from metadrive.component.pgblock.intersection import InterSection from metadrive.component.pg_space import Parameter +from metadrive.component.pgblock.intersection import InterSection, InterSectionWithUTurn class StdInterSection(InterSection): @@ -7,3 +7,10 @@ def _try_plug_into_previous_block(self) -> bool: self._config[Parameter.change_lane_num] = 0 success = super(StdInterSection, self)._try_plug_into_previous_block() return success + + +class StdInterSectionWithUTurn(InterSectionWithUTurn): + def _try_plug_into_previous_block(self) -> bool: + self._config[Parameter.change_lane_num] = 0 + success = super(StdInterSectionWithUTurn, self)._try_plug_into_previous_block() + return success diff --git a/metadrive/component/road_network/node_road_network.py b/metadrive/component/road_network/node_road_network.py index 7ecf8b6a3..85f02c642 100644 --- a/metadrive/component/road_network/node_road_network.py +++ b/metadrive/component/road_network/node_road_network.py @@ -272,6 +272,7 @@ def shortest_path(self, start: str, goal: str) -> List[str]: Returns: The shortest checkpoints from start to goal. """ + assert isinstance(goal, str) start_road_node = start[0] assert start != goal return next(self.bfs_paths(start_road_node, goal), []) diff --git a/metadrive/component/sensors/distance_detector.py b/metadrive/component/sensors/distance_detector.py index 8261496f8..eae9cce4f 100644 --- a/metadrive/component/sensors/distance_detector.py +++ b/metadrive/component/sensors/distance_detector.py @@ -89,7 +89,6 @@ class DistanceDetector(BaseSensor): """ It is a module like lidar, used to detect sidewalk/center line or other static things """ - Lidar_point_cloud_obs_dim = 240 DEFAULT_HEIGHT = 0.2 # for vis debug @@ -196,7 +195,7 @@ def __init__(self, engine): super(SideDetector, self).__init__(engine) self.set_start_phase_offset(90) self.origin.hide(CamMask.RgbCam | CamMask.Shadow | CamMask.Shadow | CamMask.DepthCam | CamMask.SemanticCam) - self.mask = CollisionGroup.ContinuousLaneLine + self.mask = CollisionGroup.ContinuousLaneLine | CollisionGroup.Sidewalk class LaneLineDetector(SideDetector): @@ -206,4 +205,4 @@ def __init__(self, engine): super(SideDetector, self).__init__(engine) self.set_start_phase_offset(90) self.origin.hide(CamMask.RgbCam | CamMask.Shadow | CamMask.Shadow | CamMask.DepthCam | CamMask.SemanticCam) - self.mask = CollisionGroup.ContinuousLaneLine | CollisionGroup.BrokenLaneLine + self.mask = CollisionGroup.ContinuousLaneLine | CollisionGroup.BrokenLaneLine | CollisionGroup.Sidewalk diff --git a/metadrive/component/sensors/lidar.py b/metadrive/component/sensors/lidar.py index 6722eaa61..a68d49f90 100644 --- a/metadrive/component/sensors/lidar.py +++ b/metadrive/component/sensors/lidar.py @@ -15,7 +15,6 @@ class Lidar(DistanceDetector): ANGLE_FACTOR = True - Lidar_point_cloud_obs_dim = 240 DEFAULT_HEIGHT = 1.2 BROAD_PHASE_EXTRA_DIST = 0 diff --git a/metadrive/component/vehicle/base_vehicle.py b/metadrive/component/vehicle/base_vehicle.py index f6890b89a..071a69d18 100644 --- a/metadrive/component/vehicle/base_vehicle.py +++ b/metadrive/component/vehicle/base_vehicle.py @@ -232,9 +232,26 @@ def before_step(self, action=None): return step_info def after_step(self): - step_info = {} if self.navigation and self.config["navigation_module"]: self.navigation.update_localization(self) + self._state_check() + self.update_dist_to_left_right() + step_energy, episode_energy = self._update_energy_consumption() + self.out_of_route = self._out_of_route() + step_info = self._update_overtake_stat() + my_policy = self.engine.get_policy(self.name) + step_info.update( + { + "velocity": float(self.speed), + "steering": float(self.steering), + "acceleration": float(self.throttle_brake), + "step_energy": step_energy, + "episode_energy": episode_energy, + "policy": my_policy.name if my_policy is not None else my_policy + } + ) + + if self.navigation is not None and hasattr(self.navigation, "navi_arrow_dir"): lanes_heading = self.navigation.navi_arrow_dir lane_0_heading = lanes_heading[0] lane_1_heading = lanes_heading[1] @@ -258,22 +275,7 @@ def after_step(self): "navigation_right": navigation_turn_right } ) - self._state_check() - self.update_dist_to_left_right() - step_energy, episode_energy = self._update_energy_consumption() - self.out_of_route = self._out_of_route() - step_info.update(self._update_overtake_stat()) - my_policy = self.engine.get_policy(self.name) - step_info.update( - { - "velocity": float(self.speed), - "steering": float(self.steering), - "acceleration": float(self.throttle_brake), - "step_energy": step_energy, - "episode_energy": episode_energy, - "policy": my_policy.name if my_policy is not None else my_policy - } - ) + return step_info def _out_of_route(self): @@ -512,14 +514,15 @@ def _apply_throttle_brake(self, throttle_brake): def update_dist_to_left_right(self): self.dist_to_left_side, self.dist_to_right_side = self._dist_to_route_left_right() - def _dist_to_route_left_right(self): - # TODO - if self.navigation is None or self.navigation.current_ref_lanes is None: + def _dist_to_route_left_right(self, navigation=None): + if navigation is None: + navigation = self.navigation + if navigation is None or navigation.current_ref_lanes is None: return 0, 0 - current_reference_lane = self.navigation.current_ref_lanes[0] + current_reference_lane = navigation.current_ref_lanes[0] _, lateral_to_reference = current_reference_lane.local_coordinates(self.position) - lateral_to_left = lateral_to_reference + self.navigation.get_current_lane_width() / 2 - lateral_to_right = self.navigation.get_current_lateral_range(self.position, self.engine) - lateral_to_left + lateral_to_left = lateral_to_reference + navigation.get_current_lane_width() / 2 + lateral_to_right = navigation.get_current_lateral_range(self.position, self.engine) - lateral_to_left return lateral_to_left, lateral_to_right # @property diff --git a/metadrive/engine/interface.py b/metadrive/engine/interface.py index 3961f0b8d..9c8842786 100644 --- a/metadrive/engine/interface.py +++ b/metadrive/engine/interface.py @@ -188,6 +188,8 @@ def destroy(self): self.left_panel.destroy() def _update_navi_arrow(self, lanes_heading): + if not self.engine.global_config["vehicle_config"]["show_navigation_arrow"]: + return lane_0_heading = lanes_heading[0] lane_1_heading = lanes_heading[1] if abs(lane_0_heading - lane_1_heading) < 0.01: diff --git a/metadrive/envs/base_env.py b/metadrive/envs/base_env.py index 8846238d9..4102fe914 100644 --- a/metadrive/envs/base_env.py +++ b/metadrive/envs/base_env.py @@ -120,6 +120,8 @@ show_line_to_dest=False, # Whether to draw a line from current vehicle position to the next navigation point show_line_to_navi_mark=False, + # Whether to draw left / right arrow in the interface to denote the navigation direction + show_navigation_arrow=True, # If set to True, the vehicle will be in color green in top-down renderer or MARL setting use_special_color=False, # Clear wheel friction, so it can not move by setting steering and throttle/brake. Used for ReplayPolicy diff --git a/metadrive/envs/metadrive_env.py b/metadrive/envs/metadrive_env.py index 29c445306..fea4c2ec6 100644 --- a/metadrive/envs/metadrive_env.py +++ b/metadrive/envs/metadrive_env.py @@ -72,6 +72,7 @@ out_of_road_penalty=5.0, crash_vehicle_penalty=5.0, crash_object_penalty=5.0, + crash_sidewalk_penalty=0.0, driving_reward=1.0, speed_reward=0.1, use_lateral_reward=False, @@ -83,6 +84,7 @@ # ===== Termination Scheme ===== out_of_route_done=False, + out_of_road_done=True, on_continuous_line_done=True, on_broken_line_done=False, crash_vehicle_done=True, @@ -160,7 +162,7 @@ def done_function(self, vehicle_id: str): "Episode ended! Scenario Index: {} Reason: arrive_dest.".format(self.current_seed), extra={"log_once": True} ) - if done_info[TerminationState.OUT_OF_ROAD]: + if done_info[TerminationState.OUT_OF_ROAD] and self.config["out_of_road_done"]: done = True self.logger.info( "Episode ended! Scenario Index: {} Reason: out_of_road.".format(self.current_seed), @@ -280,7 +282,8 @@ def reward_function(self, vehicle_id: str): reward = -self.config["crash_vehicle_penalty"] elif vehicle.crash_object: reward = -self.config["crash_object_penalty"] - + elif vehicle.crash_sidewalk: + reward = -self.config["crash_sidewalk_penalty"] step_info["route_completion"] = vehicle.navigation.route_completion return reward, step_info diff --git a/metadrive/envs/multigoal_intersection.py b/metadrive/envs/multigoal_intersection.py new file mode 100644 index 000000000..5ca0c655e --- /dev/null +++ b/metadrive/envs/multigoal_intersection.py @@ -0,0 +1,616 @@ +""" +This file provides a multi-goal environment based on the intersection environment. The environment fully support +conventional MetaDrive PG maps, where there is a special config['use_pg_map'] to enable the PG maps and all config are +the same as MetaDriveEnv. +If config['use_pg_map'] is False, the environment will use an intersection map and the goals information for all +possible destinations will be provided. +""" +from collections import defaultdict + +import gymnasium as gym +import numpy as np +import seaborn as sns + +from metadrive.component.navigation_module.node_network_navigation import NodeNetworkNavigation +from metadrive.component.pg_space import ParameterSpace, Parameter, DiscreteSpace, BoxSpace +from metadrive.component.pgblock.first_block import FirstPGBlock +from metadrive.component.pgblock.intersection import InterSectionWithUTurn +from metadrive.component.road_network import Road +from metadrive.constants import DEFAULT_AGENT +from metadrive.engine.logger import get_logger +from metadrive.envs.metadrive_env import MetaDriveEnv +from metadrive.manager.base_manager import BaseManager +from metadrive.obs.state_obs import BaseObservation, StateObservation +from metadrive.utils.math import clip, norm + +logger = get_logger() + +EGO_STATE_DIM = 5 +NAVI_DIM = 10 +GOAL_DEPENDENT_STATE_DIM = 3 + + +class CustomizedObservation(BaseObservation): + def __init__(self, config): + self.state_obs = StateObservation(config) + super(CustomizedObservation, self).__init__(config) + self.latest_observation = {} + + self.lane_detect_dim = self.config['vehicle_config']['lane_line_detector']['num_lasers'] + self.side_detect_dim = self.config['vehicle_config']['side_detector']['num_lasers'] + self.vehicle_detect_dim = self.config['vehicle_config']['lidar']['num_lasers'] + + @property + def observation_space(self): + shape = ( + EGO_STATE_DIM + self.side_detect_dim + self.lane_detect_dim + self.vehicle_detect_dim + NAVI_DIM + + GOAL_DEPENDENT_STATE_DIM, + ) + return gym.spaces.Box(-1.0, 1.0, shape=shape, dtype=np.float32) + + def observe(self, vehicle, navigation=None): + ego = self.state_observe(vehicle) + assert ego.shape[0] == EGO_STATE_DIM + + obs = [ego] + + if vehicle.config["side_detector"]["num_lasers"] > 0: + side = self.side_detector_observe(vehicle) + assert side.shape[0] == self.side_detect_dim + obs.append(side) + self.latest_observation["side_detect"] = side + + if vehicle.config["lane_line_detector"]["num_lasers"] > 0: + lane = self.lane_line_detector_observe(vehicle) + assert lane.shape[0] == self.lane_detect_dim + obs.append(lane) + self.latest_observation["lane_detect"] = lane + + if vehicle.config["lidar"]["num_lasers"] > 0: + veh = self.vehicle_detector_observe(vehicle) + assert veh.shape[0] == self.vehicle_detect_dim + obs.append(veh) + self.latest_observation["vehicle_detect"] = veh + if navigation is None: + navigation = vehicle.navigation + navi = navigation.get_navi_info() + assert len(navi) == NAVI_DIM + obs.append(navi) + + # Goal-dependent infos + goal_dependent_info = [] + lateral_to_left, lateral_to_right = vehicle._dist_to_route_left_right(navigation=navigation) + if self.engine.current_map: + total_width = float((self.engine.current_map.MAX_LANE_NUM + 1) * self.engine.current_map.MAX_LANE_WIDTH) + else: + total_width = 100 + lateral_to_left /= total_width + lateral_to_right /= total_width + goal_dependent_info += [clip(lateral_to_left, 0.0, 1.0), clip(lateral_to_right, 0.0, 1.0)] + current_reference_lane = navigation.current_ref_lanes[-1] + goal_dependent_info += [ + # The angular difference between vehicle's heading and the lane heading at this location. + vehicle.heading_diff(current_reference_lane), + ] + goal_dependent_info = np.asarray(goal_dependent_info) + assert goal_dependent_info.shape[0] == GOAL_DEPENDENT_STATE_DIM + obs.append(goal_dependent_info) + + obs = np.concatenate(obs) + + self.latest_observation["state"] = ego + self.latest_observation["raw_navi"] = navi + + return obs + + def state_observe(self, vehicle): + # update out of road + info = np.zeros([ + EGO_STATE_DIM, + ]) + + # The velocity of target vehicle + info[0] = clip((vehicle.speed_km_h + 1) / (vehicle.max_speed_km_h + 1), 0.0, 1.0) + + # Current steering + info[1] = clip((vehicle.steering / vehicle.MAX_STEERING + 1) / 2, 0.0, 1.0) + + # The normalized actions at last steps + info[2] = clip((vehicle.last_current_action[1][0] + 1) / 2, 0.0, 1.0) + info[3] = clip((vehicle.last_current_action[1][1] + 1) / 2, 0.0, 1.0) + + # Current angular acceleration (yaw rate) + heading_dir_last = vehicle.last_heading_dir + heading_dir_now = vehicle.heading + cos_beta = heading_dir_now.dot(heading_dir_last) / (norm(*heading_dir_now) * norm(*heading_dir_last)) + beta_diff = np.arccos(clip(cos_beta, 0.0, 1.0)) + yaw_rate = beta_diff / 0.1 + info[4] = clip(yaw_rate, 0.0, 1.0) + + return info + + def side_detector_observe(self, vehicle): + return np.asarray( + self.engine.get_sensor("side_detector").perceive( + vehicle, + num_lasers=vehicle.config["side_detector"]["num_lasers"], + distance=vehicle.config["side_detector"]["distance"], + physics_world=vehicle.engine.physics_world.static_world, + show=vehicle.config["show_side_detector"], + ).cloud_points + ) + + def lane_line_detector_observe(self, vehicle): + return np.asarray( + self.engine.get_sensor("lane_line_detector").perceive( + vehicle, + vehicle.engine.physics_world.static_world, + num_lasers=vehicle.config["lane_line_detector"]["num_lasers"], + distance=vehicle.config["lane_line_detector"]["distance"], + show=vehicle.config["show_lane_line_detector"], + ).cloud_points + ) + + def vehicle_detector_observe(self, vehicle): + cloud_points, detected_objects = self.engine.get_sensor("lidar").perceive( + vehicle, + physics_world=self.engine.physics_world.dynamic_world, + num_lasers=vehicle.config["lidar"]["num_lasers"], + distance=vehicle.config["lidar"]["distance"], + show=vehicle.config["show_lidar"], + ) + return np.asarray(cloud_points) + + def destroy(self): + """ + Clear allocated memory + """ + self.state_obs.destroy() + super(CustomizedObservation, self).destroy() + self.cloud_points = None + self.detected_objects = None + + +class CustomizedIntersection(InterSectionWithUTurn): + PARAMETER_SPACE = ParameterSpace( + { + Parameter.radius: BoxSpace(min=9, max=20.0), + Parameter.change_lane_num: DiscreteSpace(min=0, max=2), + Parameter.decrease_increase: DiscreteSpace(min=0, max=0) + } + ) + + +class MultiGoalIntersectionNavigationManager(BaseManager): + """ + This manager is responsible for managing multiple navigation modules, each of which is responsible for guiding the + agent to a specific goal. + """ + GOALS = { + "u_turn": (-Road(FirstPGBlock.NODE_2, FirstPGBlock.NODE_3)).end_node, + "right_turn": Road( + CustomizedIntersection.node(block_idx=1, part_idx=0, road_idx=0), + CustomizedIntersection.node(block_idx=1, part_idx=0, road_idx=1) + ).end_node, + "go_straight": Road( + CustomizedIntersection.node(block_idx=1, part_idx=1, road_idx=0), + CustomizedIntersection.node(block_idx=1, part_idx=1, road_idx=1) + ).end_node, + "left_turn": Road( + CustomizedIntersection.node(block_idx=1, part_idx=2, road_idx=0), + CustomizedIntersection.node(block_idx=1, part_idx=2, road_idx=1) + ).end_node, + } + + def __init__(self): + super().__init__() + config = self.engine.global_config + vehicle_config = config["vehicle_config"] + self.navigations = {} + navi = NodeNetworkNavigation + colors = sns.color_palette("colorblind") + for c, (dest_name, road) in enumerate(self.GOALS.items()): + self.navigations[dest_name] = navi( + # self.engine, + show_navi_mark=vehicle_config["show_navi_mark"], + show_dest_mark=vehicle_config["show_dest_mark"], + show_line_to_dest=vehicle_config["show_line_to_dest"], + panda_color=colors[c], # color for navigation marker + name=dest_name, + vehicle_config=vehicle_config + ) + + @property + def agent(self): + return self.engine.agents[DEFAULT_AGENT] + + @property + def goals(self): + return self.GOALS + + def after_reset(self): + """Reset all navigation modules.""" + # print("[DEBUG]: after_reset in MultiGoalIntersectionNavigationManager") + for name, navi in self.navigations.items(): + navi.reset(self.agent, dest=self.goals[name]) + navi.update_localization(self.agent) + + def after_step(self): + """Update all navigation modules.""" + # print("[DEBUG]: after_step in MultiGoalIntersectionNavigationManager") + for name, navi in self.navigations.items(): + navi.update_localization(self.agent) + # print("Navigation {} next checkpoint: {}".format(name, navi.get_checkpoints())) + + def get_navigation(self, goal_name): + """Return the navigation module for the given goal.""" + assert goal_name in self.goals, "Invalid goal name!" + return self.navigations[goal_name] + + +class MultiGoalIntersectionEnv(MetaDriveEnv): + """ + This environment is an intersection with multiple goals. We provide the reward function, observation, termination + conditions for each goal in the info dict returned by env.reset and env.step, with prefix "goals/{goal_name}/". + """ + @classmethod + def default_config(cls): + config = MetaDriveEnv.default_config() + # config.update(VaryingDynamicsConfig) + config.update( + { + "use_multigoal_intersection": True, + + # Set the map to an Intersection + "start_seed": 0, + + # Even though the map will not change, the traffic flow will change. + "num_scenarios": 1000, + + # Remove all traffic vehicles for now. + # "traffic_density": 0.2, + + # If the vehicle does not reach the default destination, it will receive a penalty. + "wrong_way_penalty": 10.0, + # "crash_sidewalk_penalty": 10.0, + # "crash_vehicle_penalty": 10.0, + # "crash_object_penalty": 10.0, + # "out_of_road_penalty": 10.0, + "out_of_route_penalty": 0.0, + # "success_reward": 10.0, + # "driving_reward": 1.0, + # "on_continuous_line_done": True, + # "out_of_road_done": True, + "vehicle_config": { + + # Remove navigation arrows in the window as we are in multi-goal environment. + "show_navigation_arrow": False, + + # Turn off vehicle's own navigation module. + "side_detector": dict(num_lasers=120, distance=50), # laser num, distance + "lidar": dict(num_lasers=120, distance=50), + + # To avoid goal-dependent lane detection, we use Lidar to detect distance to nearby lane lines. + # Otherwise, we will ask the navigation module to provide current lane and extract the lateral + # distance directly on this lane. + "lane_line_detector": dict(num_lasers=0, distance=20) + } + } + ) + return config + + def _post_process_config(self, config): + config = super()._post_process_config(config) + if config["use_multigoal_intersection"]: + config['map'] = None + config['map_config'] = dict( + type="block_sequence", config=[ + CustomizedIntersection, + ], lane_num=2, lane_width=3.5 + ) + return config + + # def _get_agent_manager(self): + # return VaryingDynamicsAgentManager(init_observations=self._get_observations()) + + def get_single_observation(self): + return CustomizedObservation(self.config) + + # else: + # return super().get_single_observation() + # img_obs = self.config["image_observation"] + # o = ImageStateObservation(self.config) if img_obs else LidarStateObservation(self.config) + + def setup_engine(self): + super().setup_engine() + + # Introducing a new navigation manager + if self.config["use_multigoal_intersection"]: + self.engine.register_manager("goal_manager", MultiGoalIntersectionNavigationManager()) + + def _get_step_return(self, actions, engine_info): + """Add goal-dependent observation to the info dict.""" + o, r, tm, tc, i = super(MultiGoalIntersectionEnv, self)._get_step_return(actions, engine_info) + + if self.config["use_multigoal_intersection"]: + for goal_name in self.engine.goal_manager.goals.keys(): + navi = self.engine.goal_manager.get_navigation(goal_name) + goal_obs = self.observations["default_agent"].observe(self.agents[DEFAULT_AGENT], navi) + i["obs/goals/{}".format(goal_name)] = goal_obs + assert r == i["reward/default_reward"] + assert i["route_completion"] == i["route_completion/goals/default"] + + else: + i["obs/goals/default"] = self.observations["default_agent"].observe(self.agents[DEFAULT_AGENT]) + return o, r, tm, tc, i + + def _get_reset_return(self, reset_info): + """Add goal-dependent observation to the info dict.""" + o, i = super(MultiGoalIntersectionEnv, self)._get_reset_return(reset_info) + + if self.config["use_multigoal_intersection"]: + for goal_name in self.engine.goal_manager.goals.keys(): + navi = self.engine.goal_manager.get_navigation(goal_name) + goal_obs = self.observations["default_agent"].observe(self.agents[DEFAULT_AGENT], navi) + i["obs/goals/{}".format(goal_name)] = goal_obs + + else: + i["obs/goals/default"] = self.observations["default_agent"].observe(self.agents[DEFAULT_AGENT]) + + return o, i + + def _reward_per_navigation(self, vehicle, navi, goal_name): + """Compute the reward for the given goal. goal_name='default' means we use the vehicle's own navigation.""" + reward = 0.0 + + # Get goal-dependent information + if navi.current_lane in navi.current_ref_lanes: + current_lane = navi.current_lane + positive_road = 1 + else: + current_lane = navi.current_ref_lanes[0] + current_road = navi.current_road + positive_road = 1 if not current_road.is_negative_road() else -1 + long_last, _ = current_lane.local_coordinates(vehicle.last_position) + long_now, lateral_now = current_lane.local_coordinates(vehicle.position) + + # Reward for moving forward in current lane + reward += self.config["driving_reward"] * (long_now - long_last) * positive_road + + left, right = vehicle._dist_to_route_left_right(navigation=navi) + out_of_route = (right < 0) or (left < 0) + + # Reward for speed, sign determined by whether in the correct lanes (instead of driving in the wrong + # direction). + reward += self.config["speed_reward"] * (vehicle.speed_km_h / vehicle.max_speed_km_h) * positive_road + if self._is_arrive_destination(vehicle): + if self._is_arrive_destination(vehicle, goal_name): + reward += self.config["success_reward"] + else: + # if goal_name == "default": + # print("WRONG WAY") + reward = -self.config["wrong_way_penalty"] + else: + if self._is_out_of_road(vehicle): + reward = -self.config["out_of_road_penalty"] + elif vehicle.crash_vehicle: + reward = -self.config["crash_vehicle_penalty"] + elif vehicle.crash_object: + reward = -self.config["crash_object_penalty"] + elif vehicle.crash_sidewalk: + reward = -self.config["crash_sidewalk_penalty"] + elif out_of_route: + # if goal_name == "default": + # print("OUT OF ROUTE") + reward = -self.config["out_of_route_penalty"] + + return reward, navi.route_completion + + def reward_function(self, vehicle_id: str): + """ + Compared to the original reward_function, we add goal-dependent reward to info dict. + """ + vehicle = self.agents[vehicle_id] + step_info = dict() + + # Compute goal-dependent reward and saved to step_info + if self.config["use_multigoal_intersection"]: + for goal_name in self.engine.goal_manager.goals.keys(): + navi = self.engine.goal_manager.get_navigation(goal_name) + prefix = goal_name + reward, route_completion = self._reward_per_navigation(vehicle, navi, goal_name) + step_info[f"reward/goals/{prefix}"] = reward + step_info[f"route_completion/goals/{prefix}"] = route_completion + + else: + navi = vehicle.navigation + goal_name = "default" + reward, route_completion = self._reward_per_navigation(vehicle, navi, goal_name) + step_info[f"reward/goals/{goal_name}"] = reward + step_info[f"route_completion/goals/{goal_name}"] = route_completion + + default_reward, default_rc = self._reward_per_navigation(vehicle, vehicle.navigation, "default") + step_info[f"reward/goals/default"] = default_reward + step_info[f"route_completion/goals/default"] = default_rc + step_info[f"reward/default_reward"] = default_reward + step_info[f"route_completion"] = vehicle.navigation.route_completion + + return default_reward, step_info + + def _is_arrive_destination(self, vehicle, goal_name=None): + """ + Compared to the original function, here we look up the navigation from goal_manager. + + Args: + vehicle: The BaseVehicle instance. + goal_name: The name of the goal. If None, return True if any goal is arrived. + + Returns: + flag: Whether this vehicle arrives its destination. + """ + + if self.config["use_multigoal_intersection"]: + if goal_name is None: + ret = False + for name in self.engine.goal_manager.goals.keys(): + ret = ret or self._is_arrive_destination(vehicle, name) + return ret + + if goal_name == "default": + navi = self.vehicle.navigation + else: + navi = self.engine.goal_manager.get_navigation(goal_name) + + else: + navi = vehicle.navigation + + long, lat = navi.final_lane.local_coordinates(vehicle.position) + flag = (navi.final_lane.length - 5 < long < navi.final_lane.length + 5) and ( + navi.get_current_lane_width() / 2 >= lat >= + (0.5 - navi.get_current_lane_num()) * navi.get_current_lane_width() + ) + return flag + + def done_function(self, vehicle_id: str): + """ + Compared to MetaDriveEnv's done_function, we add more stats here to record which goal is arrived. + """ + done, done_info = super(MultiGoalIntersectionEnv, self).done_function(vehicle_id) + vehicle = self.agents[vehicle_id] + + if self.config["use_multigoal_intersection"]: + for goal_name in self.engine.goal_manager.goals.keys(): + done_info[f"arrive_dest/goals/{goal_name}"] = self._is_arrive_destination(vehicle, goal_name) + + else: + done_info[f"arrive_dest/goals/default"] = done + + done_info["arrive_dest/goals/default"] = self._is_arrive_destination(vehicle, "default") + + return done, done_info + + +if __name__ == "__main__": + config = dict( + use_render=True, + manual_control=True, + vehicle_config=dict( + show_navi_mark=True, + show_line_to_navi_mark=True, + show_lidar=False, + show_side_detector=True, + show_lane_line_detector=True, + ), + + # ******************************************** + use_multigoal_intersection=False + # ******************************************** + + # **{ + # "map_config": dict( + # lane_num=5, + # lane_width=3.5 + # ), + # } + ) + env = MultiGoalIntersectionEnv(config) + episode_rewards = defaultdict(float) + try: + o, info = env.reset() + + # default_ckpt = env.vehicle.navigation.checkpoints[-1] + # for goal, navi in env.engine.goal_manager.navigations.items(): + # if navi.checkpoints[-1] == default_ckpt: + # break + # assert np.all(o == info["obs/goals/{}".format(goal)]) + + goal = "default" + + print('=======================') + print("Full observation shape:\n\t", o.shape) + print("Goal-agnostic observation shape:\n\t", {k: v.shape for k, v in info.items() if k.startswith("obs/ego")}) + print("Observation shape for each goals: ") + for k in sorted(info.keys()): + if k.startswith("obs/goals/"): + print(f"\t{k}: {info[k].shape}") + print('=======================') + + obs_recorder = defaultdict(list) + + s = 0 + for i in range(1, 1000000000): + o, r, tm, tc, info = env.step([0, 1]) + + assert np.all(o == info["obs/goals/{}".format(goal)]) + assert np.all(r == info["reward/goals/{}".format(goal)]) + + done = tm or tc + s += 1 + # env.render() + env.render(mode="topdown") + + for k in info.keys(): + if k.startswith("obs/goals"): + obs_recorder[k].append(info[k]) + + for k, v in info.items(): + if k.startswith("reward/goals"): + episode_rewards[k] += v + + if s % 20 == 0: + print('\n===== timestep {} ====='.format(s)) + print('goal: ', goal) + print('route completion:') + for k in sorted(info.keys()): + if k.startswith("route_completion/goals/"): + print(f"\t{k}: {info[k]:.2f}") + + print('\nreward:') + for k in sorted(info.keys()): + if k.startswith("reward/"): + print(f"\t{k}: {info[k]:.2f}") + print('=======================') + + if done: + print('\n===== timestep {} ====='.format(s)) + print("EPISODE DONE\n") + print('route completion:') + for k in sorted(info.keys()): + # kk = k.replace("/route_completion", "") + if k.startswith("route_completion/goals/"): + print(f"\t{k}: {info[k]:.2f}") + + print('\narrive destination (success):') + for k in sorted(info.keys()): + # kk = k.replace("/arrive_dest", "") + if k.startswith("arrive_dest/goals/"): + print(f"\t{k}: {info[k]:.2f}") + + print('\nepisode_rewards:') + for k in sorted(episode_rewards.keys()): + # kk = k.replace("/step_reward", "") + print(f"\t{k}: {episode_rewards[k]:.2f}") + episode_rewards.clear() + print('=======================') + + if done: + + import numpy as np + + # for t in range(i): + # # avg = [v[t] for k, v in obs_recorder.items()] + # v = np.stack([v[0] for k, v in obs_recorder.items()]) + + print('\n\n\n') + o, info = env.reset() + + default_ckpt = env.vehicle.navigation.checkpoints[-1] + # for goal, navi in env.engine.goal_manager.navigations.items(): + # if navi.checkpoints[-1] == default_ckpt: + # break + # + # assert np.all(o == info["obs/goals/{}".format(goal)]) + + s = 0 + finally: + env.close() diff --git a/metadrive/examples/train_generalization_experiment.py b/metadrive/examples/train_generalization_experiment.py index f5e31c596..ce65b7040 100755 --- a/metadrive/examples/train_generalization_experiment.py +++ b/metadrive/examples/train_generalization_experiment.py @@ -3,7 +3,12 @@ in the same test set using rllib. We verified this script with ray==2.2.0. Please report to use if you find newer version of ray is not compatible with -this script. +this script. Installation guide: + + pip install ray[rllib]==2.2.0 + pip install tensorflow_probability==0.24.0 + pip install torch + """ import argparse import copy diff --git a/metadrive/manager/traffic_manager.py b/metadrive/manager/traffic_manager.py index e6bd4f1c4..20bc49b10 100644 --- a/metadrive/manager/traffic_manager.py +++ b/metadrive/manager/traffic_manager.py @@ -80,12 +80,12 @@ def before_step(self): engine = self.engine if self.mode != TrafficMode.Respawn: for v in engine.agent_manager.active_agents.values(): - ego_lane_idx = v.lane_index[:-1] - ego_road = Road(ego_lane_idx[0], ego_lane_idx[1]) - if len(self.block_triggered_vehicles) > 0 and \ - ego_road == self.block_triggered_vehicles[-1].trigger_road: - block_vehicles = self.block_triggered_vehicles.pop() - self._traffic_vehicles += list(self.get_objects(block_vehicles.vehicles).values()) + if len(self.block_triggered_vehicles) > 0: + ego_lane_idx = v.lane_index[:-1] + ego_road = Road(ego_lane_idx[0], ego_lane_idx[1]) + if ego_road == self.block_triggered_vehicles[-1].trigger_road: + block_vehicles = self.block_triggered_vehicles.pop() + self._traffic_vehicles += list(self.get_objects(block_vehicles.vehicles).values()) for v in self._traffic_vehicles: p = self.engine.get_policy(v.name) v.before_step(p.act()) @@ -266,7 +266,8 @@ def _create_vehicles_once(self, map: BaseMap, traffic_density: float) -> None: vehicle_type = self.random_vehicle_type() v_config.update(self.engine.global_config["traffic_vehicle_config"]) random_v = self.spawn_object(vehicle_type, vehicle_config=v_config) - self.add_policy(random_v.id, IDMPolicy, random_v, self.generate_seed()) + seed = self.generate_seed() + self.add_policy(random_v.id, IDMPolicy, random_v, seed) vehicles_on_block.append(random_v.name) trigger_road = block.pre_block_socket.positive_road @@ -310,8 +311,7 @@ def destroy(self) -> None: # current map # traffic vehicle list - self._traffic_vehicles = None - self.block_triggered_vehicles = None + self.block_triggered_vehicles = [] # traffic property self.mode = None diff --git a/metadrive/obs/state_obs.py b/metadrive/obs/state_obs.py index 69463a8e9..0c71e4e64 100644 --- a/metadrive/obs/state_obs.py +++ b/metadrive/obs/state_obs.py @@ -15,7 +15,7 @@ def __init__(self, config): if config["vehicle_config"]["navigation_module"]: navi_dim = config["vehicle_config"]["navigation_module"].get_navigation_info_dim() else: - navi_dim = NodeNetworkNavigation.get_navigation_info_dim() + navi_dim = 0 self.navi_dim = navi_dim super(StateObservation, self).__init__(config) @@ -56,9 +56,12 @@ def observe(self, vehicle): :param vehicle: BaseVehicle :return: Vehicle State + Navigation information """ - navi_info = vehicle.navigation.get_navi_info() ego_state = self.vehicle_state(vehicle) - ret = np.concatenate([ego_state, navi_info]) + if self.navi_dim > 0: + navi_info = vehicle.navigation.get_navi_info() + ret = np.concatenate([ego_state, navi_info]) + else: + ret = np.asarray(ego_state) return ret.astype(np.float32) def vehicle_state(self, vehicle): @@ -89,8 +92,8 @@ def vehicle_state(self, vehicle): # If the side detector is turn off, then add the distance to left and right road borders as state. lateral_to_left, lateral_to_right, = vehicle.dist_to_left_side, vehicle.dist_to_right_side - if vehicle.navigation.map: - total_width = float((vehicle.navigation.map.MAX_LANE_NUM + 1) * vehicle.navigation.map.MAX_LANE_WIDTH) + if self.engine.current_map: + total_width = float((self.engine.current_map.MAX_LANE_NUM + 1) * self.engine.current_map.MAX_LANE_WIDTH) else: total_width = 100 lateral_to_left /= total_width diff --git a/metadrive/tests/test_component/test_lane_line_detector.py b/metadrive/tests/test_component/test_lane_line_detector.py index 09d0211de..3fa091eea 100644 --- a/metadrive/tests/test_component/test_lane_line_detector.py +++ b/metadrive/tests/test_component/test_lane_line_detector.py @@ -264,132 +264,28 @@ ] pg_gt_3 = [ - 0.17000000178813934, - 0.18000000715255737, - 0.18000000715255737, - 0.1899999976158142, - 0.20000000298023224, - 0.2199999988079071, - 0.23999999463558197, - 0.27000001072883606, - 0.33000001311302185, - 0.4099999964237213, - 0.5600000023841858, - 1.0, - 1.0, - 0.550000011920929, - 0.18000000715255737, - 0.10999999940395355, - 0.07999999821186066, - 0.05999999865889549, - 0.05000000074505806, - 0.05000000074505806, - 0.03999999910593033, - 0.03999999910593033, - 0.03999999910593033, - 0.03999999910593033, - 0.029999999329447746, - 0.029999999329447746, - 0.029999999329447746, - 0.03999999910593033, - 0.03999999910593033, - 0.03999999910593033, - 0.03999999910593033, - 0.05000000074505806, - 0.05000000074505806, - 0.05999999865889549, - 0.07999999821186066, - 0.10999999940395355, - 0.18000000715255737, - 0.47999998927116394, - 0.7099999785423279, - 0.8799999952316284, - 1.0, - 0.4099999964237213, - 0.33000001311302185, - 0.27000001072883606, - 0.23999999463558197, - 0.2199999988079071, - 0.20000000298023224, - 0.1899999976158142, - 0.18000000715255737, - 0.18000000715255737, - 0.5, - 0.009999999776482582, - 0.5, - 0.5, - 0.5, - 0.0, - 0.17000000178813934, - 0.029999999329447746, - 0.03999999910593033, - 0.03999999910593033, - 0.03999999910593033, - 0.03999999910593033, - 0.05000000074505806, - 0.27000001072883606, - 0.1899999976158142, - 0.4099999964237213, - 0.10999999940395355, - 0.18000000715255737, - 0.5600000023841858, - 0.550000011920929, - 0.18000000715255737, - 0.10999999940395355, - 0.07999999821186066, - 0.05999999865889549, - 0.05000000074505806, - 0.05000000074505806, - 0.03999999910593033, - 0.03999999910593033, - 0.03999999910593033, - 0.03999999910593033, - 0.029999999329447746, - 0.029999999329447746, - 0.029999999329447746, - 0.03999999910593033, - 0.03999999910593033, - 0.03999999910593033, - 0.03999999910593033, - 0.05000000074505806, - 0.05000000074505806, - 0.05999999865889549, - 0.07999999821186066, - 0.10999999940395355, - 0.18000000715255737, - 0.47999998927116394, - 0.7099999785423279, - 0.7400000095367432, - 0.7799999713897705, - 0.07999999821186066, - 0.05999999865889549, - 0.05000000074505806, - 0.23999999463558197, - 0.12999999523162842, - 0.11999999731779099, - 0.10999999940395355, - 0.18000000715255737, - 0.18000000715255737, - 0.30000001192092896, - 0.4300000071525574, - 0.0, - 0.5, - 0.5, - 0.699999988079071, - 0.4300000071525574, - 0.0, - 0.5, - 0.5, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, + 0.17424996197223663, 0.17563384771347046, 0.179901123046875, 0.18740931153297424, 0.19884595274925232, + 0.21538487076759338, 0.23903638124465942, 0.273365318775177, 0.32519838213920593, 0.40924975275993347, + 0.5638853311538696, 1.0, 1.0, 0.5454577803611755, 0.18278196454048157, 0.11083516478538513, 0.08044067770242691, + 0.06391915678977966, 0.05373150855302811, 0.04698386788368225, 0.04233507066965103, 0.03907269984483719, + 0.03683317080140114, 0.03535793721675873, 0.034397244453430176, 0.03412747383117676, 0.034522198140621185, + 0.0353609174489975, 0.036836810410022736, 0.03908421844244003, 0.04233531653881073, 0.04698418080806732, + 0.05373169109225273, 0.06391977518796921, 0.08044077455997467, 0.11082261800765991, 0.18277062475681305, + 0.45523008704185486, 0.7089459300041199, 0.8811144232749939, 1.0, 0.40924930572509766, 0.325198233127594, + 0.2732689678668976, 0.23903624713420868, 0.21538473665714264, 0.19884586334228516, 0.1874106079339981, + 0.17990022897720337, 0.17563492059707642, 0.49999988079071045, 0.012532129883766174, 0.5, 0.5, 0.5, + 1.3244441561255371e-06, 0.17424996197223663, 0.03450844809412956, 0.035360947251319885, 0.03683680295944214, + 0.03908339887857437, 0.04233534261584282, 0.04698362201452255, 0.273365318775177, 0.19455832242965698, + 0.40924975275993347, 0.1108340248465538, 0.18278250098228455, 0.5611181855201721, 0.5454577803611755, + 0.18278196454048157, 0.11083516478538513, 0.08044067770242691, 0.06391915678977966, 0.05373150855302811, + 0.04698386788368225, 0.04233507066965103, 0.03907269984483719, 0.03683317080140114, 0.03535793721675873, + 0.034397244453430176, 0.03412747383117676, 0.034522198140621185, 0.0353609174489975, 0.036836810410022736, + 0.03908421844244003, 0.04233531653881073, 0.04698418080806732, 0.05373169109225273, 0.06391977518796921, + 0.08044077455997467, 0.11082261800765991, 0.18277062475681305, 0.45523008704185486, 0.7089459300041199, + 0.7421090602874756, 0.7847582697868347, 0.08044064044952393, 0.0639198049902916, 0.053731828927993774, + 0.23903624713420868, 0.1288066953420639, 0.11896516382694244, 0.11209990084171295, 0.17990022897720337, + 0.17563492059707642, 0.3000001609325409, 0.43000006675720215, 0.0, 0.5, 0.5, 0.699988842010498, 0.4299999475479126, + 0.0, 0.5, 0.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ] @@ -412,7 +308,7 @@ def test_pg_map(render=False): ) try: env.reset() - env.vehicle.set_position([73, 12]) + env.agent.set_position([73, 12]) for s in range(1, 5): o, r, tm, tc, info = env.step([0, 0]) @@ -422,7 +318,7 @@ def test_pg_map(render=False): print("]") np.testing.assert_almost_equal(pg_gt_1, np.round(o, 2), decimal=2) - env.vehicle.set_position([30, 3.5]) + env.agent.set_position([30, 3.5]) for s in range(1, 5): o, r, tm, tc, info = env.step([0, 0]) print("[") @@ -431,13 +327,16 @@ def test_pg_map(render=False): print("]") np.testing.assert_almost_equal(np.array(pg_gt_2), o, decimal=2) - env.vehicle.set_position([30, 10.5]) + env.agent.set_position([30, 10.5]) for s in range(1, 5): o, r, tm, tc, info = env.step([0, 0]) print("[") - for _o in o: - print("{},".format(round(_o, 2))) + for ind, _o in enumerate(o): + print("{:.2f}, GT: {:.2f}".format(_o, pg_gt_3[ind])) print("]") + + print(o.tolist()) + np.testing.assert_almost_equal(np.array(pg_gt_3), o, decimal=2) finally: env.close() @@ -755,7 +654,7 @@ def test_nuscenes(render=False): ) try: env.reset(seed=0) - env.vehicle.set_position([-9.4, -27.2]) + env.agent.set_position([-9.4, -27.2]) for s in range(1, 5): o, r, tm, tc, info = env.step([0, 0]) @@ -766,7 +665,7 @@ def test_nuscenes(render=False): np.testing.assert_almost_equal(nuscenes_gt_1, o, decimal=3) env.reset(seed=1) - env.vehicle.set_position([79.96, -6.2]) + env.agent.set_position([79.96, -6.2]) for s in range(1, 5): o, r, tm, tc, info = env.step([0, 0]) @@ -781,4 +680,4 @@ def test_nuscenes(render=False): if __name__ == '__main__': # test_nuscenes(True) - test_pg_map(True) + test_pg_map(False) diff --git a/metadrive/utils/registry.py b/metadrive/utils/registry.py index 4411b2fce..b0f1dc87f 100644 --- a/metadrive/utils/registry.py +++ b/metadrive/utils/registry.py @@ -13,7 +13,7 @@ def _initialize_registry(): from metadrive.component.pgblock.parking_lot import ParkingLot from metadrive.component.pgblock.ramp import InRampOnStraight, OutRampOnStraight from metadrive.component.pgblock.roundabout import Roundabout - from metadrive.component.pgblock.std_intersection import StdInterSection + from metadrive.component.pgblock.std_intersection import StdInterSection, StdInterSectionWithUTurn from metadrive.component.pgblock.std_t_intersection import StdTInterSection from metadrive.component.pgblock.straight import Straight from metadrive.component.pgblock.tollgate import TollGate @@ -21,7 +21,7 @@ def _initialize_registry(): _metadrive_class_list.extend( [ Merge, Split, Curve, InFork, OutFork, ParkingLot, InRampOnStraight, OutRampOnStraight, Roundabout, - StdInterSection, StdTInterSection, Straight, TollGate, Bidirection + StdInterSection, StdTInterSection, StdInterSectionWithUTurn, Straight, TollGate, Bidirection ] ) diff --git a/setup.py b/setup.py index da5e82984..affbbbbb1 100644 --- a/setup.py +++ b/setup.py @@ -61,6 +61,7 @@ def is_win(): "shapely", "filelock", "Pygments", + "mediapy" ]