From e55201ad0696eee3d5def9df43e98b3cb7d5a25b Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 5 Nov 2024 14:58:50 +0100 Subject: [PATCH] Patch obs space and update trained agents --- CHANGELOG.md | 1 + requirements.txt | 2 +- rl-trained-agents | 2 +- rl_zoo3/enjoy.py | 3 +++ 4 files changed, 6 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 937cc078c..95f56c118 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ ### New Features - Added `CrossQ` hyperparameters for SB3-contrib (@danielpalen) - Added Gymnasium v1.0 support +- `--custom-objects` in `enjoy.py` now also patches obs space (when bounds are changed) to solve "Observation spaces do not match" errors ### Bug fixes - Replaced deprecated `huggingface_hub.Repository` when pushing to Hugging Face Hub by the recommended `HfApi` (see https://huggingface.co/docs/huggingface_hub/concepts/git_vs_http) (@cochaviz) diff --git a/requirements.txt b/requirements.txt index 13ebb4773..cda9d4521 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ gym==0.26.2 stable-baselines3[extra,tests,docs]>=2.4.0a11,<3.0 box2d-py==2.3.8 -pybullet_envs_gymnasium>=0.4.0 +pybullet_envs_gymnasium>=0.5.0 # minigrid cloudpickle>=2.2.1 # optuna plots: diff --git a/rl-trained-agents b/rl-trained-agents index ca4371d8e..5351cb204 160000 --- a/rl-trained-agents +++ b/rl-trained-agents @@ -1 +1 @@ -Subproject commit ca4371d8eef7c2eece81461f3d138d23743b2296 +Subproject commit 5351cb204619964f7bd3e2aa9672b93e0a037b23 diff --git a/rl_zoo3/enjoy.py b/rl_zoo3/enjoy.py index 4cb717a7d..55c7bfd99 100644 --- a/rl_zoo3/enjoy.py +++ b/rl_zoo3/enjoy.py @@ -184,12 +184,15 @@ def enjoy() -> None: # noqa: C901 "learning_rate": 0.0, "lr_schedule": lambda _: 0.0, "clip_range": lambda _: 0.0, + "observation_space": env.observation_space, # load models with different obs bounds } if "HerReplayBuffer" in hyperparams.get("replay_buffer_class", ""): kwargs["env"] = env model = 
ALGOS[algo].load(model_path, custom_objects=custom_objects, device=args.device, **kwargs) + # Uncomment to overwrite the saved model with the patched version (for instance after a gym -> gymnasium migration) + # model.save(model_path) obs = env.reset() # Deterministic by default except for atari games