diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cb9055266..d34a93c9a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,7 +21,12 @@ jobs: strategy: matrix: python-version: ["3.8", "3.9", "3.10", "3.11"] - + include: + # Default version + - gymnasium-version: "1.0.0" + # Add a new config to test gym<1.0 + - python-version: "3.10" + gymnasium-version: "0.29.1" steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} @@ -37,15 +42,14 @@ jobs: # See https://github.com/astral-sh/uv/issues/1497 uv pip install --system torch==2.3.1+cpu --index https://download.pytorch.org/whl/cpu - # Install Atari Roms - uv pip install --system autorom - wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64 - base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz - AutoROM --accept-license --source-file Roms.tar.gz - - uv pip install --system .[extra_no_roms,tests,docs] + uv pip install --system .[extra,tests,docs] # Use headless version uv pip install --system opencv-python-headless + - name: Install specific version of gym + run: | + uv pip install --system gymnasium==${{ matrix.gymnasium-version }} + # Only run for python 3.10, downgrade gym to 0.29.1 + if: matrix.gymnasium-version != '1.0.0' - name: Lint with ruff run: | make lint diff --git a/docs/conda_env.yml b/docs/conda_env.yml index e025a57e1..ac065b3b9 100644 --- a/docs/conda_env.yml +++ b/docs/conda_env.yml @@ -8,7 +8,7 @@ dependencies: - python=3.11 - pytorch=2.5.0=py3.11_cpu_0 - pip: - - gymnasium>=0.28.1,<0.30 + - gymnasium>=0.29.1,<1.1.0 - cloudpickle - opencv-python-headless - pandas diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index b32cd7ce1..cf2a2a520 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,10 +3,10 @@ Changelog ========== -Release 2.4.0a10 (WIP) +Release 2.4.0a11 (WIP) -------------------------- -**New algorithm: CrossQ in SB3 Contrib** +**New algorithm: CrossQ in SB3 Contrib, Gymnasium v1.0 support** .. note:: @@ -24,12 +24,14 @@ Release 2.4.0a10 (WIP) Breaking Changes: ^^^^^^^^^^^^^^^^^ +- Increase minimum required version of Gymnasium to 0.29.1 New Features: ^^^^^^^^^^^^^ - Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ) - Enabled np.ndarray logging for TensorBoardOutputFormat as histogram (see GH#1634) (@iwishwasaneagle) - Updated env checker to warn users when using multi-dim array to define `MultiDiscrete` spaces +- Added support for Gymnasium v1.0 Bug Fixes: ^^^^^^^^^^ @@ -69,6 +71,7 @@ Others: - Added a warning to recommend using CPU with on policy algorithms (A2C/PPO) and ``MlpPolicy`` - Switched to uv to download packages faster on GitHub CI - Updated dependencies for read the doc +- Removed unnecessary ``copy_obs_dict`` method for ``SubprocVecEnv``, remove the use of ordered dict and rename ``flatten_obs`` to ``stack_obs`` Bug Fixes: ^^^^^^^^^^ diff --git a/pyproject.toml b/pyproject.toml index dd435a33e..1fd1a1890 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,6 @@ ignore = ["B028", "RUF013"] # ClassVar, implicit optional check not needed for tests "./tests/*.py" = ["RUF012", "RUF013"] - [tool.ruff.lint.mccabe] # Unlike Flake8, default to a complexity level of 10. max-complexity = 15 diff --git a/setup.py b/setup.py index 9d56dfd77..52f626462 100644 --- a/setup.py +++ b/setup.py @@ -70,37 +70,13 @@ """ # noqa:E501 -# Atari Games download is sometimes problematic: -# https://github.com/Farama-Foundation/AutoROM/issues/39 -# That's why we define extra packages without it. -extra_no_roms = [ - # For render - "opencv-python", - "pygame", - # Tensorboard support - "tensorboard>=2.9.1", - # Checking memory taken by replay buffer - "psutil", - # For progress bar callback - "tqdm", - "rich", - # For atari games, - "shimmy[atari]~=1.3.0", - "pillow", -] - -extra_packages = extra_no_roms + [ # noqa: RUF005 - # For atari roms, - "autorom[accept-rom-license]~=0.6.1", -] - setup( name="stable_baselines3", packages=[package for package in find_packages() if package.startswith("stable_baselines3")], package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ - "gymnasium>=0.28.1,<0.30", + "gymnasium>=0.29.1,<1.1.0", "numpy>=1.20,<2.0", # PyTorch not compatible https://github.com/pytorch/pytorch/issues/107302 "torch>=1.13", # For saving models @@ -133,8 +109,21 @@ # Copy button for code snippets "sphinx_copybutton", ], - "extra": extra_packages, - "extra_no_roms": extra_no_roms, + "extra": [ + # For render + "opencv-python", + "pygame", + # Tensorboard support + "tensorboard>=2.9.1", + # Checking memory taken by replay buffer + "psutil", + # For progress bar callback + "tqdm", + "rich", + # For atari games, + "ale-py>=0.9.0", + "pillow", + ], }, description="Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.", author="Antonin Raffin", diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index 15ecfb681..5625e2453 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -8,7 +8,7 @@ from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn from stable_baselines3.common.vec_env.patch_gym import _patch_env -from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info +from stable_baselines3.common.vec_env.util import dict_to_obs, obs_space_info class DummyVecEnv(VecEnv): @@ -110,12 +110,12 @@ def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None: self.buf_obs[key][env_idx] = obs[key] # type: ignore[call-overload] def _obs_from_buf(self) -> VecEnvObs: - return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs)) + return dict_to_obs(self.observation_space, deepcopy(self.buf_obs)) def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: """Return attribute from vectorized environment (see base class).""" target_envs = self._get_target_envs(indices) - return [getattr(env_i, attr_name) for env_i in target_envs] + return [env_i.get_wrapper_attr(attr_name) for env_i in target_envs] def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None: """Set attribute inside vectorized environments (see base class).""" @@ -126,7 +126,7 @@ def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]: """Call instance methods of vectorized environments.""" target_envs = self._get_target_envs(indices) - return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs] + return [env_i.get_wrapper_attr(method_name)(*method_args, **method_kwargs) for env_i in target_envs] def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]: """Check if worker environments are wrapped with a given wrapper""" diff --git a/stable_baselines3/common/vec_env/patch_gym.py b/stable_baselines3/common/vec_env/patch_gym.py index 6ba655ebf..874809a03 100644 --- a/stable_baselines3/common/vec_env/patch_gym.py +++ b/stable_baselines3/common/vec_env/patch_gym.py @@ -43,7 +43,7 @@ def _patch_env(env: Union["gym.Env", gymnasium.Env]) -> gymnasium.Env: # pragma "Missing shimmy installation. You provided an OpenAI Gym environment. " "Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. " "In order to use OpenAI Gym environments with SB3, you need to " - "install shimmy (`pip install 'shimmy>=0.2.1'`)." + "install shimmy (`pip install 'shimmy>=2.0'`)." ) from e warnings.warn( diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index c598c735a..a606a7cb9 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -1,6 +1,5 @@ import multiprocessing as mp import warnings -from collections import OrderedDict from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union import gymnasium as gym @@ -54,10 +53,10 @@ def _worker( elif cmd == "get_spaces": remote.send((env.observation_space, env.action_space)) elif cmd == "env_method": - method = getattr(env, data[0]) + method = env.get_wrapper_attr(data[0]) remote.send(method(*data[1], **data[2])) elif cmd == "get_attr": - remote.send(getattr(env, data)) + remote.send(env.get_wrapper_attr(data)) elif cmd == "set_attr": remote.send(setattr(env, data[0], data[1])) # type: ignore[func-returns-value] elif cmd == "is_wrapped": @@ -129,7 +128,7 @@ def step_wait(self) -> VecEnvStepReturn: results = [remote.recv() for remote in self.remotes] self.waiting = False obs, rews, dones, infos, self.reset_infos = zip(*results) # type: ignore[assignment] - return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos # type: ignore[return-value] + return _stack_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos # type: ignore[return-value] def reset(self) -> VecEnvObs: for env_idx, remote in enumerate(self.remotes): @@ -139,7 +138,7 @@ def reset(self) -> VecEnvObs: # Seeds and options are only used once self._reset_seeds() self._reset_options() - return _flatten_obs(obs, self.observation_space) + return _stack_obs(obs, self.observation_space) def close(self) -> None: if self.closed: @@ -206,27 +205,28 @@ def _get_target_remotes(self, indices: VecEnvIndices) -> List[Any]: return [self.remotes[i] for i in indices] -def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: spaces.Space) -> VecEnvObs: +def _stack_obs(obs_list: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: spaces.Space) -> VecEnvObs: """ - Flatten observations, depending on the observation space. + Stack observations (convert from a list of single env obs to a stack of obs), + depending on the observation space. :param obs: observations. A list or tuple of observations, one per environment. Each environment observation may be a NumPy array, or a dict or tuple of NumPy arrays. - :return: flattened observations. - A flattened NumPy array or an OrderedDict or tuple of flattened numpy arrays. + :return: Concatenated observations. + A NumPy array or a dict or tuple of stacked numpy arrays. Each NumPy array has the environment index as its first axis. """ - assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment" - assert len(obs) > 0, "need observations from at least one environment" + assert isinstance(obs_list, (list, tuple)), "expected list or tuple of observations per environment" + assert len(obs_list) > 0, "need observations from at least one environment" if isinstance(space, spaces.Dict): - assert isinstance(space.spaces, OrderedDict), "Dict space must have ordered subspaces" - assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space" - return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()]) + assert isinstance(space.spaces, dict), "Dict space must have ordered subspaces" + assert isinstance(obs_list[0], dict), "non-dict observation for environment with Dict observation space" + return {key: np.stack([single_obs[key] for single_obs in obs_list]) for key in space.spaces.keys()} # type: ignore[call-overload] elif isinstance(space, spaces.Tuple): - assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space" + assert isinstance(obs_list[0], tuple), "non-tuple observation for environment with Tuple observation space" obs_len = len(space.spaces) - return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len)) # type: ignore[index] + return tuple(np.stack([single_obs[i] for single_obs in obs_list]) for i in range(obs_len)) # type: ignore[index] else: - return np.stack(obs) # type: ignore[arg-type] + return np.stack(obs_list) # type: ignore[arg-type] diff --git a/stable_baselines3/common/vec_env/util.py b/stable_baselines3/common/vec_env/util.py index 855f50edc..6ea04f6ab 100644 --- a/stable_baselines3/common/vec_env/util.py +++ b/stable_baselines3/common/vec_env/util.py @@ -2,7 +2,6 @@ Helpers for dealing with vectorized environments. """ -from collections import OrderedDict from typing import Any, Dict, List, Tuple import numpy as np @@ -12,17 +11,6 @@ from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs -def copy_obs_dict(obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: - """ - Deep-copy a dict of numpy arrays. - - :param obs: a dict of numpy arrays. - :return: a dict of copied numpy arrays. - """ - assert isinstance(obs, OrderedDict), f"unexpected type for observations '{type(obs)}'" - return OrderedDict([(k, np.copy(v)) for k, v in obs.items()]) - - def dict_to_obs(obs_space: spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs: """ Convert an internal representation raw_obs into the appropriate type @@ -60,13 +48,13 @@ def obs_space_info(obs_space: spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[ """ check_for_nested_spaces(obs_space) if isinstance(obs_space, spaces.Dict): - assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces" + assert isinstance(obs_space.spaces, dict), "Dict space must have ordered subspaces" subspaces = obs_space.spaces elif isinstance(obs_space, spaces.Tuple): - subspaces = {i: space for i, space in enumerate(obs_space.spaces)} # type: ignore[assignment] + subspaces = {i: space for i, space in enumerate(obs_space.spaces)} # type: ignore[assignment,misc] else: assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'" - subspaces = {None: obs_space} # type: ignore[assignment] + subspaces = {None: obs_space} # type: ignore[assignment,dict-item] keys = [] shapes = {} dtypes = {} @@ -74,4 +62,4 @@ def obs_space_info(obs_space: spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[ keys.append(key) shapes[key] = box.shape dtypes[key] = box.dtype - return keys, shapes, dtypes + return keys, shapes, dtypes # type: ignore[return-value] diff --git a/stable_baselines3/common/vec_env/vec_video_recorder.py b/stable_baselines3/common/vec_env/vec_video_recorder.py index 52faebd1f..e586f94ab 100644 --- a/stable_baselines3/common/vec_env/vec_video_recorder.py +++ b/stable_baselines3/common/vec_env/vec_video_recorder.py @@ -1,7 +1,9 @@ import os -from typing import Callable +import os.path +from typing import Callable, List -from gymnasium.wrappers.monitoring import video_recorder +import numpy as np +from gymnasium import error, logger from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv @@ -13,6 +15,11 @@ class VecVideoRecorder(VecEnvWrapper): Wraps a VecEnv or VecEnvWrapper object to record rendered image as mp4 video. It requires ffmpeg or avconv to be installed on the machine. + Note: for now it only allows to record one video and all videos + must have at least two frames. + + The video recorder code was adapted from Gymnasium v1.0. + :param venv: :param video_folder: Where to save videos :param record_video_trigger: Function that defines when to start recording. @@ -22,8 +29,6 @@ class VecVideoRecorder(VecEnvWrapper): :param name_prefix: Prefix to the video name """ - video_recorder: video_recorder.VideoRecorder - def __init__( self, venv: VecEnv, @@ -51,6 +56,8 @@ def __init__( self.env.metadata = metadata assert self.env.render_mode == "rgb_array", f"The render_mode must be 'rgb_array', not {self.env.render_mode}" + self.frames_per_sec = self.env.metadata.get("render_fps", 30) + self.record_video_trigger = record_video_trigger self.video_folder = os.path.abspath(video_folder) # Create output folder if needed @@ -60,54 +67,88 @@ def __init__( self.step_id = 0 self.video_length = video_length + self.video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}.mp4" + self.video_path = os.path.join(self.video_folder, self.video_name) + self.recording = False - self.recorded_frames = 0 + self.recorded_frames: list[np.ndarray] = [] + + try: + import moviepy # noqa: F401 + except ImportError as e: + raise error.DependencyNotInstalled("MoviePy is not installed, run `pip install 'gymnasium[other]'`") from e def reset(self) -> VecEnvObs: obs = self.venv.reset() - self.start_video_recorder() + if self._video_enabled(): + self._start_video_recorder() return obs - def start_video_recorder(self) -> None: - self.close_video_recorder() - - video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}" - base_path = os.path.join(self.video_folder, video_name) - self.video_recorder = video_recorder.VideoRecorder( - env=self.env, base_path=base_path, metadata={"step_id": self.step_id} - ) - - self.video_recorder.capture_frame() - self.recorded_frames = 1 - self.recording = True + def _start_video_recorder(self) -> None: + self._start_recording() + self._capture_frame() def _video_enabled(self) -> bool: return self.record_video_trigger(self.step_id) def step_wait(self) -> VecEnvStepReturn: - obs, rews, dones, infos = self.venv.step_wait() + obs, rewards, dones, infos = self.venv.step_wait() self.step_id += 1 if self.recording: - self.video_recorder.capture_frame() - self.recorded_frames += 1 - if self.recorded_frames > self.video_length: - print(f"Saving video to {self.video_recorder.path}") - self.close_video_recorder() + self._capture_frame() + if len(self.recorded_frames) > self.video_length: + print(f"Saving video to {self.video_path}") + self._stop_recording() elif self._video_enabled(): - self.start_video_recorder() + self._start_video_recorder() - return obs, rews, dones, infos + return obs, rewards, dones, infos - def close_video_recorder(self) -> None: - if self.recording: - self.video_recorder.close() - self.recording = False - self.recorded_frames = 1 + def _capture_frame(self) -> None: + assert self.recording, "Cannot capture a frame, recording wasn't started." + + frame = self.env.render() + if isinstance(frame, List): + frame = frame[-1] + + if isinstance(frame, np.ndarray): + self.recorded_frames.append(frame) + else: + self._stop_recording() + logger.warn( + f"Recording stopped: expected type of frame returned by render to be a numpy array, got instead {type(frame)}." + ) def close(self) -> None: + """Closes the wrapper then the video recorder.""" VecEnvWrapper.close(self) - self.close_video_recorder() + if self.recording: + self._stop_recording() + + def _start_recording(self) -> None: + """Start a new recording. If it is already recording, stops the current recording before starting the new one.""" + if self.recording: + self._stop_recording() + + self.recording = True + + def _stop_recording(self) -> None: + """Stop current recording and saves the video.""" + assert self.recording, "_stop_recording was called, but no recording was started" + + if len(self.recorded_frames) == 0: + logger.warn("Ignored saving a video as there were zero frames to save.") + else: + from moviepy.video.io.ImageSequenceClip import ImageSequenceClip + + clip = ImageSequenceClip(self.recorded_frames, fps=self.frames_per_sec) + clip.write_videofile(self.video_path) + + self.recorded_frames = [] + self.recording = False - def __del__(self): - self.close_video_recorder() + def __del__(self) -> None: + """Warn the user in case last video wasn't saved.""" + if len(self.recorded_frames) > 0: + logger.warn("Unable to save last video! Did you call close()?") diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 852a32b3f..d5cafdb5a 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.4.0a10 +2.4.0a11 diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index f093e47e7..8049c6887 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -117,12 +117,11 @@ def test_consistency(model_class): """ use_discrete_actions = model_class == DQN dict_env = DummyDictEnv(use_discrete_actions=use_discrete_actions, vec_only=True) + dict_env.seed(10) dict_env = gym.wrappers.TimeLimit(dict_env, 100) env = gym.wrappers.FlattenObservation(dict_env) - dict_env.seed(10) obs, _ = dict_env.reset() - kwargs = {} n_steps = 256 if model_class in {A2C, PPO}: diff --git a/tests/test_gae.py b/tests/test_gae.py index 83b95a4c0..bb674cffa 100644 --- a/tests/test_gae.py +++ b/tests/test_gae.py @@ -73,7 +73,7 @@ def _on_rollout_end(self): buffer = self.model.rollout_buffer rollout_size = buffer.size() - max_steps = self.training_env.envs[0].max_steps + max_steps = self.training_env.envs[0].get_wrapper_attr("max_steps") gamma = self.model.gamma gae_lambda = self.model.gae_lambda value = self.model.policy.constant_value diff --git a/tests/test_logger.py b/tests/test_logger.py index bc18bf2ce..02d36b306 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -592,6 +592,7 @@ def test_rollout_success_rate_onpolicy_algo(tmp_path): """ STATS_WINDOW_SIZE = 10 + # Add dummy successes with 0.3, 0.5 and 0.8 success_rate of length STATS_WINDOW_SIZE dummy_successes = [ [True] * 3 + [False] * 7, @@ -603,16 +604,17 @@ def test_rollout_success_rate_onpolicy_algo(tmp_path): # Monitor the env to track the success info monitor_file = str(tmp_path / "monitor.csv") env = Monitor(DummySuccessEnv(dummy_successes, ep_steps), filename=monitor_file, info_keywords=("is_success",)) + steps_per_log = env.unwrapped.steps_per_log # Equip the model of a custom logger to check the success_rate info - model = PPO("MlpPolicy", env=env, stats_window_size=STATS_WINDOW_SIZE, n_steps=env.steps_per_log, verbose=1) + model = PPO("MlpPolicy", env=env, stats_window_size=STATS_WINDOW_SIZE, n_steps=steps_per_log, verbose=1) logger = InMemoryLogger() model.set_logger(logger) # Make the model learn and check that the success rate corresponds to the ratio of dummy successes - model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1) + model.learn(total_timesteps=steps_per_log * ep_steps, log_interval=1) assert logger.name_to_value["rollout/success_rate"] == 0.3 - model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1) + model.learn(total_timesteps=steps_per_log * ep_steps, log_interval=1) assert logger.name_to_value["rollout/success_rate"] == 0.5 - model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1) + model.learn(total_timesteps=steps_per_log * ep_steps, log_interval=1) assert logger.name_to_value["rollout/success_rate"] == 0.8 diff --git a/tests/test_utils.py b/tests/test_utils.py index 81f134168..bb2ebd067 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,7 @@ import os import shutil +import ale_py import gymnasium as gym import numpy as np import pytest @@ -24,6 +25,8 @@ ) from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv +gym.register_envs(ale_py) + @pytest.mark.parametrize("env_id", ["CartPole-v1", lambda: gym.make("CartPole-v1")]) @pytest.mark.parametrize("n_envs", [1, 2]) diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index a9516ae25..3aa52762d 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -307,7 +307,7 @@ def test_vecenv_dict_spaces(vec_env_class): space = spaces.Dict(SPACES) def obs_assert(obs): - assert isinstance(obs, collections.OrderedDict) + assert isinstance(obs, dict) assert obs.keys() == space.spaces.keys() for key, values in obs.items(): check_vecenv_obs(values, space.spaces[key])