From 73ff16f094fe94a422532d993046b0462c6ea807 Mon Sep 17 00:00:00 2001 From: ruiheng123 <121342992+ruiheng123@users.noreply.github.com> Date: Thu, 20 Jun 2024 15:05:59 +0800 Subject: [PATCH] feature(wrh): add taxi env latest version and dqn config (#807) * update taxi env --- README.md | 2 +- dizoo/taxi/config/taxi_dqn_config.py | 46 ++++++++++++++++------------ dizoo/taxi/envs/taxi_env.py | 10 ++++-- dizoo/taxi/envs/test_taxi_env.py | 4 +-- 4 files changed, 36 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index 77893134f3..cbf9b9d007 100644 --- a/README.md +++ b/README.md @@ -324,7 +324,7 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo` | 37 | [tabmwp](https://promptpg.github.io/explore.html) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/tabmwp/tabmwp.jpeg) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/tabmwp)
env tutorial
环境指南 | | 38 | [frozen_lake](https://gymnasium.farama.org/environments/toy_text/frozen_lake) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/frozen_lake/FrozenLake.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/frozen_lake)
env tutorial
环境指南 | | 39 | [ising_model](https://github.com/mlii/mfrl/tree/master/examples/ising_model) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![marl](https://img.shields.io/badge/-MARL-yellow) | ![original](./dizoo/ising_env/ising_env.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/ising_env)
env tutorial
[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/ising_model_zh.html) | -| 40 | [taxi](https://www.gymlibrary.dev/environments/toy_text/taxi/) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/taxi/Taxi-v3_episode_0.gif) | dizoo link
env tutorial
环境指南 | +| 40 | [taxi](https://www.gymlibrary.dev/environments/toy_text/taxi/) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/taxi/Taxi-v3_episode_0.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/taxi/envs)
[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/taxi.html)
[环境指南](https://di-engine-docs.readthedocs.io/zh-cn/latest/13_envs/taxi_zh.html) | diff --git a/dizoo/taxi/config/taxi_dqn_config.py b/dizoo/taxi/config/taxi_dqn_config.py index a168368ba9..dabbceaad7 100644 --- a/dizoo/taxi/config/taxi_dqn_config.py +++ b/dizoo/taxi/config/taxi_dqn_config.py @@ -1,39 +1,45 @@ from easydict import EasyDict taxi_dqn_config = dict( - exp_name='taxi_seed0', + exp_name='taxi_dqn_seed0', env=dict( collector_env_num=8, - evaluator_env_num=8, - n_evaluator_episode=10, - max_episode_steps=300, - env_id="Taxi-v3" + evaluator_env_num=8, + n_evaluator_episode=8, + stop_value=20, + max_episode_steps=60, + env_id="Taxi-v3" ), policy=dict( cuda=True, - load_path="./taxi_dqn_seed0/ckpt/ckpt_best.pth.tar", model=dict( - obs_shape=4, + obs_shape=34, action_shape=6, - encoder_hidden_size_list=[256, 128, 64] + encoder_hidden_size_list=[128, 128] ), + random_collect_size=5000, nstep=3, - discount_factor=0.98, + discount_factor=0.99, learn=dict( - update_per_collect=5, - batch_size=128, - learning_rate=0.001, + update_per_collect=10, + batch_size=64, + learning_rate=0.0001, + learner=dict( + hook=dict( + log_show_after_iter=1000, + ) + ), ), - collect=dict(n_sample=10), - eval=dict(evaluator=dict(eval_freq=5, )), + collect=dict(n_sample=32), + eval=dict(evaluator=dict(eval_freq=1000, )), other=dict( eps=dict( type="linear", - start=0.8, - end=0.1, - decay=10000 - ), - replay_buffer=dict(replay_buffer_size=20000,), + start=1, + end=0.05, + decay=3000000 + ), + replay_buffer=dict(replay_buffer_size=100000,), ), ) ) @@ -55,4 +61,4 @@ if __name__ == "__main__": from ding.entry import serial_pipeline - serial_pipeline((main_config, create_config), max_env_step=5000, seed=0) \ No newline at end of file + serial_pipeline((main_config, create_config), max_env_step=3000000, seed=0) \ No newline at end of file diff --git a/dizoo/taxi/envs/taxi_env.py b/dizoo/taxi/envs/taxi_env.py index 8a026b6b6a..a2d5285e58 100644 --- a/dizoo/taxi/envs/taxi_env.py +++ b/dizoo/taxi/envs/taxi_env.py @@ -93,8 +93,8 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep: def enable_save_replay(self, replay_path: Optional[str] = None) -> None: if replay_path is None: replay_path = './video' - if not os.path.exists(replay_path): - os.makedirs(replay_path) + if not os.path.exists(replay_path): + os.makedirs(replay_path) self._replay_path = replay_path self._save_replay = True self._save_replay_count = 0 @@ -118,7 +118,11 @@ def random_action(self) -> np.ndarray: #todo encode the state into a vector def _encode_taxi(self, obs: np.ndarray) -> np.ndarray: taxi_row, taxi_col, passenger_location, destination = self._env.unwrapped.decode(obs) - return to_ndarray([taxi_row, taxi_col, passenger_location, destination]) + encoded_obs = np.zeros(34) + encoded_obs[5 * taxi_row + taxi_col] = 1 + encoded_obs[25 + passenger_location] = 1 + encoded_obs[30 + destination] = 1 + return to_ndarray(encoded_obs) @property def observation_space(self) -> Space: diff --git a/dizoo/taxi/envs/test_taxi_env.py b/dizoo/taxi/envs/test_taxi_env.py index 917e3f9917..7334ce4a08 100644 --- a/dizoo/taxi/envs/test_taxi_env.py +++ b/dizoo/taxi/envs/test_taxi_env.py @@ -16,7 +16,7 @@ def test_naive(self): env.seed(314, dynamic_seed=False) assert env._seed == 314 obs = env.reset() - assert obs.shape == (4, ) + assert obs.shape == (34, ) for _ in range(5): env.reset() np.random.seed(314) @@ -32,7 +32,7 @@ def test_naive(self): print(f"Your timestep in wrapped mode is: {timestep}") assert isinstance(timestep.obs, np.ndarray) assert isinstance(timestep.done, bool) - assert timestep.obs.shape == (4, ) + assert timestep.obs.shape == (34, ) assert timestep.reward.shape == (1, ) assert timestep.reward >= env.reward_space.low assert timestep.reward <= env.reward_space.high