From f4c5b1e5e2f3ca95f1e2c37634b6252a835aa624 Mon Sep 17 00:00:00 2001 From: Corentin <111868204+corentinlger@users.noreply.github.com> Date: Sun, 24 Sep 2023 12:36:52 +0200 Subject: [PATCH] Fix check_env for Sequence observation space (#1690) * Fix Sequence obs env_checker * Fix Sequence obs env_checker * Add test : env_checker for Sequence obs * Add test : env_checker for Sequence obs * Cleanup and improve env checker messages --------- Co-authored-by: Antonin Raffin --- docs/misc/changelog.rst | 25 +++++++++++++------------ stable_baselines3/common/env_checker.py | 25 +++++++++++++++++++++---- stable_baselines3/version.txt | 2 +- tests/test_env_checker.py | 22 ++++++++++++++++++++++ 4 files changed, 57 insertions(+), 17 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 7204bf6d9..a0a29b0e3 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.2.0a3 (WIP) +Release 2.2.0a4 (WIP) -------------------------- Breaking Changes: @@ -12,16 +12,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ - -`SB3-Contrib`_ -^^^^^^^^^^^^^^ - -`RL Zoo`_ -^^^^^^^^^ - -`SBX`_ -^^^^^^^^^ -- Added ``DDPG`` and ``TD3`` +- Improved error message of the ``env_checker`` for env wrongly detected as GoalEnv (``compute_reward()`` is defined) Bug Fixes: ^^^^^^^^^^ @@ -33,7 +24,17 @@ Bug Fixes: - Fixed replay buffer device after loading in OffPolicyAlgorithm (@PatrickHelm) - Fixed ``render_mode`` which was not properly loaded when using ``VecNormalize.load()`` - Fixed success reward dtype in ``SimpleMultiObsEnv`` (@NixGD) +- Fixed check_env for Sequence observation space (@corentinlger) +`SB3-Contrib`_ +^^^^^^^^^^^^^^ + +`RL Zoo`_ +^^^^^^^^^ + +`SBX`_ +^^^^^^^^^ +- Added ``DDPG`` and ``TD3`` Deprecations: ^^^^^^^^^^^^^ @@ -1459,4 +1460,4 @@ And all the contributors: @Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto -@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm +@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index 8b8da7f44..dc465a1d6 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -80,7 +80,7 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act if isinstance(observation_space, spaces.Tuple): warnings.warn( - "The observation space is a Tuple," + "The observation space is a Tuple, " "this is currently not supported by Stable Baselines3. " "However, you can convert it to a Dict observation space " "(cf. https://gymnasium.farama.org/api/spaces/composite/#dict). " @@ -93,6 +93,13 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act "You can use a wrapper or update your observation space." ) + if isinstance(observation_space, spaces.Sequence): + warnings.warn( + "Sequence observation space is not supported by Stable-Baselines3. " + "You can pad your observation to have a fixed size instead.\n" + "Note: The checks for returned values are skipped." + ) + if isinstance(action_space, spaces.Discrete) and action_space.start != 0: warnings.warn( "Discrete action space with a non-zero start is not supported by Stable-Baselines3. " @@ -347,9 +354,15 @@ def _check_spaces(env: gym.Env) -> None: assert isinstance(env.action_space, spaces.Space), f"The action space must inherit from gymnasium.spaces ({gym_spaces})" if _is_goal_env(env): - assert isinstance( - env.observation_space, spaces.Dict - ), "Goal conditioned envs (previously gym.GoalEnv) require the observation space to be gymnasium.spaces.Dict" + print( + "We detected your env to be a GoalEnv because `env.compute_reward()` was defined.\n" + "If it's not the case, please rename `env.compute_reward()` to something else to avoid False positives." + ) + assert isinstance(env.observation_space, spaces.Dict), ( + "Goal conditioned envs (previously gym.GoalEnv) require the observation space to be gymnasium.spaces.Dict.\n" + "Note: if your env is not a GoalEnv, please rename `env.compute_reward()` " + "to something else to avoid False positive." + ) # Check render cannot be covered by CI @@ -440,6 +453,10 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) - f"Your action space has dtype {action_space.dtype}, we recommend using np.float32 to avoid cast errors." ) + # If Sequence observation space, do not check the observation any further + if isinstance(observation_space, spaces.Sequence): + return + # ============ Check the returned values =============== _check_returned_values(env, observation_space, action_space) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index aaceff257..ddcf0926b 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.2.0a3 +2.2.0a4 diff --git a/tests/test_env_checker.py b/tests/test_env_checker.py index 87cc177b7..62dd6ffa6 100644 --- a/tests/test_env_checker.py +++ b/tests/test_env_checker.py @@ -37,6 +37,28 @@ def test_check_env_dict_action(): check_env(env=test_env, warn=True) +class SequenceObservationEnv(gym.Env): + metadata = {"render_modes": [], "render_fps": 2} + + def __init__(self, render_mode=None): + self.observation_space = spaces.Sequence(spaces.Discrete(8)) + self.action_space = spaces.Discrete(4) + + def reset(self, seed=None, options=None): + super().reset(seed=seed) + return self.observation_space.sample(), {} + + def step(self, action): + return self.observation_space.sample(), 1.0, False, False, {} + + +def test_check_env_sequence_obs(): + test_env = SequenceObservationEnv() + + with pytest.warns(Warning, match="Sequence.*not supported"): + check_env(env=test_env, warn=True) + + @pytest.mark.parametrize( "obs_tuple", [