diff --git a/README.md b/README.md
index 77893134f3..cbf9b9d007 100644
--- a/README.md
+++ b/README.md
@@ -324,7 +324,7 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
| 37 | [tabmwp](https://promptpg.github.io/explore.html) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/tabmwp/tabmwp.jpeg) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/tabmwp)
env tutorial
环境指南 |
| 38 | [frozen_lake](https://gymnasium.farama.org/environments/toy_text/frozen_lake) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/frozen_lake/FrozenLake.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/frozen_lake)
env tutorial
环境指南 |
| 39 | [ising_model](https://github.com/mlii/mfrl/tree/master/examples/ising_model) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![marl](https://img.shields.io/badge/-MARL-yellow) | ![original](./dizoo/ising_env/ising_env.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/ising_env)
env tutorial
[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/ising_model_zh.html) |
-| 40 | [taxi](https://www.gymlibrary.dev/environments/toy_text/taxi/) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/taxi/Taxi-v3_episode_0.gif) | dizoo link
env tutorial
环境指南 |
+| 40 | [taxi](https://www.gymlibrary.dev/environments/toy_text/taxi/) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/taxi/Taxi-v3_episode_0.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/taxi)
[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/taxi.html)
+[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/taxi_zh.html) |
diff --git a/dizoo/taxi/config/taxi_dqn_config.py b/dizoo/taxi/config/taxi_dqn_config.py
index a168368ba9..dabbceaad7 100644
--- a/dizoo/taxi/config/taxi_dqn_config.py
+++ b/dizoo/taxi/config/taxi_dqn_config.py
@@ -1,39 +1,45 @@
from easydict import EasyDict
taxi_dqn_config = dict(
- exp_name='taxi_seed0',
+ exp_name='taxi_dqn_seed0',
env=dict(
collector_env_num=8,
- evaluator_env_num=8,
- n_evaluator_episode=10,
- max_episode_steps=300,
- env_id="Taxi-v3"
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=20,
+ max_episode_steps=60,
+ env_id="Taxi-v3"
),
policy=dict(
cuda=True,
- load_path="./taxi_dqn_seed0/ckpt/ckpt_best.pth.tar",
model=dict(
- obs_shape=4,
+ obs_shape=34,
action_shape=6,
- encoder_hidden_size_list=[256, 128, 64]
+ encoder_hidden_size_list=[128, 128]
),
+ random_collect_size=5000,
nstep=3,
- discount_factor=0.98,
+ discount_factor=0.99,
learn=dict(
- update_per_collect=5,
- batch_size=128,
- learning_rate=0.001,
+ update_per_collect=10,
+ batch_size=64,
+ learning_rate=0.0001,
+ learner=dict(
+ hook=dict(
+ log_show_after_iter=1000,
+ )
+ ),
),
- collect=dict(n_sample=10),
- eval=dict(evaluator=dict(eval_freq=5, )),
+ collect=dict(n_sample=32),
+ eval=dict(evaluator=dict(eval_freq=1000, )),
other=dict(
eps=dict(
type="linear",
- start=0.8,
- end=0.1,
- decay=10000
- ),
- replay_buffer=dict(replay_buffer_size=20000,),
+ start=1,
+ end=0.05,
+ decay=3000000
+ ),
+ replay_buffer=dict(replay_buffer_size=100000,),
),
)
)
@@ -55,4 +61,4 @@
if __name__ == "__main__":
from ding.entry import serial_pipeline
- serial_pipeline((main_config, create_config), max_env_step=5000, seed=0)
\ No newline at end of file
+ serial_pipeline((main_config, create_config), max_env_step=3000000, seed=0)
\ No newline at end of file
diff --git a/dizoo/taxi/envs/taxi_env.py b/dizoo/taxi/envs/taxi_env.py
index 8a026b6b6a..a2d5285e58 100644
--- a/dizoo/taxi/envs/taxi_env.py
+++ b/dizoo/taxi/envs/taxi_env.py
@@ -93,8 +93,8 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep:
def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
if replay_path is None:
replay_path = './video'
- if not os.path.exists(replay_path):
- os.makedirs(replay_path)
+ if not os.path.exists(replay_path):
+ os.makedirs(replay_path)
self._replay_path = replay_path
self._save_replay = True
self._save_replay_count = 0
@@ -118,7 +118,11 @@ def random_action(self) -> np.ndarray:
#todo encode the state into a vector
def _encode_taxi(self, obs: np.ndarray) -> np.ndarray:
taxi_row, taxi_col, passenger_location, destination = self._env.unwrapped.decode(obs)
- return to_ndarray([taxi_row, taxi_col, passenger_location, destination])
+ encoded_obs = np.zeros(34)
+ encoded_obs[5 * taxi_row + taxi_col] = 1
+ encoded_obs[25 + passenger_location] = 1
+ encoded_obs[30 + destination] = 1
+ return to_ndarray(encoded_obs)
@property
def observation_space(self) -> Space:
diff --git a/dizoo/taxi/envs/test_taxi_env.py b/dizoo/taxi/envs/test_taxi_env.py
index 917e3f9917..7334ce4a08 100644
--- a/dizoo/taxi/envs/test_taxi_env.py
+++ b/dizoo/taxi/envs/test_taxi_env.py
@@ -16,7 +16,7 @@ def test_naive(self):
env.seed(314, dynamic_seed=False)
assert env._seed == 314
obs = env.reset()
- assert obs.shape == (4, )
+ assert obs.shape == (34, )
for _ in range(5):
env.reset()
np.random.seed(314)
@@ -32,7 +32,7 @@ def test_naive(self):
print(f"Your timestep in wrapped mode is: {timestep}")
assert isinstance(timestep.obs, np.ndarray)
assert isinstance(timestep.done, bool)
- assert timestep.obs.shape == (4, )
+ assert timestep.obs.shape == (34, )
assert timestep.reward.shape == (1, )
assert timestep.reward >= env.reward_space.low
assert timestep.reward <= env.reward_space.high