From 08e5f9a8f41e158f4a6bbdf53bb594308b6472b2 Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Tue, 13 Feb 2024 17:09:42 +0000 Subject: [PATCH 01/24] Update Gymnasium to v1.0.0a1 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 763a6a376..76edf2b68 100644 --- a/setup.py +++ b/setup.py @@ -100,7 +100,7 @@ packages=[package for package in find_packages() if package.startswith("stable_baselines3")], package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ - "gymnasium>=0.28.1,<0.30", + "gymnasium==1.0.0a1", "numpy>=1.20", "torch>=1.13", # For saving models From f73c08e20a5146b4309786e0f3b51d4fe5b13db6 Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Tue, 13 Feb 2024 17:40:35 +0000 Subject: [PATCH 02/24] Comment out `gymnasium.wrappers.monitor` (todo update to VideoRecord) --- stable_baselines3/common/vec_env/vec_video_recorder.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/stable_baselines3/common/vec_env/vec_video_recorder.py b/stable_baselines3/common/vec_env/vec_video_recorder.py index 52faebd1f..1c4e6b88b 100644 --- a/stable_baselines3/common/vec_env/vec_video_recorder.py +++ b/stable_baselines3/common/vec_env/vec_video_recorder.py @@ -1,8 +1,6 @@ import os from typing import Callable -from gymnasium.wrappers.monitoring import video_recorder - from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv @@ -22,7 +20,7 @@ class VecVideoRecorder(VecEnvWrapper): :param name_prefix: Prefix to the video name """ - video_recorder: video_recorder.VideoRecorder + # video_recorder: video_recorder.VideoRecorder def __init__( self, @@ -73,9 +71,9 @@ def start_video_recorder(self) -> None: video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}" base_path = os.path.join(self.video_folder, video_name) - self.video_recorder = video_recorder.VideoRecorder( - env=self.env, base_path=base_path, metadata={"step_id": self.step_id} - ) + # self.video_recorder = video_recorder.VideoRecorder( + # env=self.env, base_path=base_path, metadata={"step_id": self.step_id} + # ) self.video_recorder.capture_frame() self.recorded_frames = 1 From 08d3ac92d92fe80da5ace11b2dad6772cc283f6b Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Tue, 13 Feb 2024 17:40:57 +0000 Subject: [PATCH 03/24] Fix ruff warnings --- pyproject.toml | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1195687f4..95c80c9c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,20 +4,19 @@ line-length = 127 # Assume Python 3.8 target-version = "py38" # See https://beta.ruff.rs/docs/rules/ -select = ["E", "F", "B", "UP", "C90", "RUF"] +lint.select = ["E", "F", "B", "UP", "C90", "RUF"] # B028: Ignore explicit stacklevel` # RUF013: Too many false positives (implicit optional) -ignore = ["B028", "RUF013"] +lint.ignore = ["B028", "RUF013"] -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Default implementation in abstract methods "./stable_baselines3/common/callbacks.py"= ["B027"] "./stable_baselines3/common/noise.py"= ["B027"] # ClassVar, implicit optional check not needed for tests "./tests/*.py"= ["RUF012", "RUF013"] - -[tool.ruff.mccabe] +[tool.ruff.lint.mccabe] # Unlike Flake8, default to a complexity level of 10. max-complexity = 15 From eb55500f355d35385be4a8b5b0a866586f732606 Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Tue, 13 Feb 2024 18:32:27 +0000 Subject: [PATCH 04/24] Register Atari envs --- tests/test_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_utils.py b/tests/test_utils.py index 4cc8b7e9f..08cdcdf78 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -24,6 +24,10 @@ ) from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv +# a hack to get atari environment registered for 1.0.0 alpha 1 +from shimmy import registration +registration._register_atari_envs() + @pytest.mark.parametrize("env_id", ["CartPole-v1", lambda: gym.make("CartPole-v1")]) @pytest.mark.parametrize("n_envs", [1, 2]) From 686d1a0c713ace895f38e063db81437e4189f690 Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Tue, 13 Feb 2024 18:34:00 +0000 Subject: [PATCH 05/24] Update `getattr` to `Env.get_wrapper_attr` --- stable_baselines3/common/vec_env/dummy_vec_env.py | 4 ++-- stable_baselines3/common/vec_env/subproc_vec_env.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index 15ecfb681..a37d3f254 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -115,7 +115,7 @@ def _obs_from_buf(self) -> VecEnvObs: def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: """Return attribute from vectorized environment (see base class).""" target_envs = self._get_target_envs(indices) - return [getattr(env_i, attr_name) for env_i in target_envs] + return [env_i.get_wrapper_attr(attr_name) for env_i in target_envs] def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None: """Set attribute inside vectorized environments (see base class).""" @@ -126,7 +126,7 @@ def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]: """Call instance methods of vectorized environments.""" target_envs = self._get_target_envs(indices) - return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs] + return [env_i.get_wrapper_attr(method_name)(*method_args, **method_kwargs) for env_i in target_envs] def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]: """Check if worker environments are wrapped with a given wrapper""" diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index c598c735a..a29bd4dc9 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -54,10 +54,10 @@ def _worker( elif cmd == "get_spaces": remote.send((env.observation_space, env.action_space)) elif cmd == "env_method": - method = getattr(env, data[0]) + method = env.get_wrapper_attr(data[0]) remote.send(method(*data[1], **data[2])) elif cmd == "get_attr": - remote.send(getattr(env, data)) + remote.send(env.get_wrapper_attr(data)) elif cmd == "set_attr": remote.send(setattr(env, data[0], data[1])) # type: ignore[func-returns-value] elif cmd == "is_wrapped": From da48aedf0ee2d2bf4124f4d01c6c821bf8a22dbb Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Tue, 13 Feb 2024 18:39:26 +0000 Subject: [PATCH 06/24] Reorder imports --- tests/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 08cdcdf78..63eacafba 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -6,6 +6,7 @@ import pytest import torch as th from gymnasium import spaces +from shimmy import registration import stable_baselines3 as sb3 from stable_baselines3 import A2C @@ -25,7 +26,6 @@ from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv # a hack to get atari environment registered for 1.0.0 alpha 1 -from shimmy import registration registration._register_atari_envs() From b063f941ad66a6d83cbb38cdad00a0f1f30c4962 Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Tue, 13 Feb 2024 18:54:29 +0000 Subject: [PATCH 07/24] Fix `seed` order --- tests/test_dict_env.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index 14777452e..102a07892 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -117,12 +117,11 @@ def test_consistency(model_class): """ use_discrete_actions = model_class == DQN dict_env = DummyDictEnv(use_discrete_actions=use_discrete_actions, vec_only=True) + dict_env.seed(10) dict_env = gym.wrappers.TimeLimit(dict_env, 100) env = gym.wrappers.FlattenObservation(dict_env) - dict_env.seed(10) obs, _ = dict_env.reset() - kwargs = {} n_steps = 256 if model_class in {A2C, PPO}: From 6e11f934666b799152f6891a18fb2811bcc8391b Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Tue, 13 Feb 2024 18:56:56 +0000 Subject: [PATCH 08/24] Fix collecting `max_steps` --- tests/test_gae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_gae.py b/tests/test_gae.py index 83b95a4c0..bb674cffa 100644 --- a/tests/test_gae.py +++ b/tests/test_gae.py @@ -73,7 +73,7 @@ def _on_rollout_end(self): buffer = self.model.rollout_buffer rollout_size = buffer.size() - max_steps = self.training_env.envs[0].max_steps + max_steps = self.training_env.envs[0].get_wrapper_attr("max_steps") gamma = self.model.gamma gae_lambda = self.model.gae_lambda value = self.model.policy.constant_value From 39f09007c9e7e017abf9e0db7923b08c0f37a624 Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Wed, 3 Apr 2024 22:26:09 +0100 Subject: [PATCH 09/24] Copy and paste video recorder to prevent the need to rewrite the vec vide recorder wrapper --- .../common/vec_env/vec_video_recorder.py | 177 +++++++++++++++++- 1 file changed, 172 insertions(+), 5 deletions(-) diff --git a/stable_baselines3/common/vec_env/vec_video_recorder.py b/stable_baselines3/common/vec_env/vec_video_recorder.py index 1c4e6b88b..76804bb41 100644 --- a/stable_baselines3/common/vec_env/vec_video_recorder.py +++ b/stable_baselines3/common/vec_env/vec_video_recorder.py @@ -1,11 +1,180 @@ +import json import os -from typing import Callable +import os.path +import tempfile +from typing import Callable, List, Optional + +import numpy as np +from gymnasium import error, logger from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv +# This is copy and pasted from Gymnasium v0.26.1 +class VideoRecorder: + """VideoRecorder renders a nice movie of a rollout, frame by frame. + + It comes with an ``enabled`` option, so you can still use the same code on episodes where you don't want to record video. + + Note: + You are responsible for calling :meth:`close` on a created VideoRecorder, or else you may leak an encoder process. + """ + + def __init__( + self, + env, + path: Optional[str] = None, + metadata: Optional[dict] = None, + enabled: bool = True, + base_path: Optional[str] = None, + ): + """Video recorder renders a nice movie of a rollout, frame by frame. + + Args: + env (Env): Environment to take video of. + path (Optional[str]): Path to the video file; will be randomly chosen if omitted. + metadata (Optional[dict]): Contents to save to the metadata file. + enabled (bool): Whether to actually record video, or just no-op (for convenience) + base_path (Optional[str]): Alternatively, path to the video file without extension, which will be added. + + Raises: + Error: You can pass at most one of `path` or `base_path` + Error: Invalid path given that must have a particular file extension + """ + try: + # check that moviepy is now installed + import moviepy # noqa: F401 + except ImportError as e: + raise error.DependencyNotInstalled("MoviePy is not installed, run `pip install moviepy`") from e + + self._async = env.metadata.get("semantics.async") + self.enabled = enabled + self._closed = False + + self.render_history: list[np.ndarray] = [] + self.env = env + + self.render_mode = env.render_mode + + if "rgb_array_list" != self.render_mode and "rgb_array" != self.render_mode: + logger.warn( + f"Disabling video recorder because environment {env} was not initialized with any compatible video " + "mode between `rgb_array` and `rgb_array_list`" + ) + # Disable since the environment has not been initialized with a compatible `render_mode` + self.enabled = False + + # Don't bother setting anything else if not enabled + if not self.enabled: + return + + if path is not None and base_path is not None: + raise error.Error("You can pass at most one of `path` or `base_path`.") + + required_ext = ".mp4" + if path is None: + if base_path is not None: + # Base path given, append ext + path = base_path + required_ext + else: + # Otherwise, just generate a unique filename + with tempfile.NamedTemporaryFile(suffix=required_ext) as f: + path = f.name + self.path = path + + path_base, actual_ext = os.path.splitext(self.path) + + if actual_ext != required_ext: + raise error.Error(f"Invalid path given: {self.path} -- must have file extension {required_ext}.") + + self.frames_per_sec = env.metadata.get("render_fps", 30) + + self.broken = False + + # Dump metadata + self.metadata = metadata or {} + self.metadata["content_type"] = "video/mp4" + self.metadata_path = f"{path_base}.meta.json" + self.write_metadata() + + logger.info(f"Starting new video recorder writing to {self.path}") + self.recorded_frames: list[np.ndarray] = [] + + @property + def functional(self): + """Returns if the video recorder is functional, is enabled and not broken.""" + return self.enabled and not self.broken + + def capture_frame(self): + """Render the given `env` and add the resulting frame to the video.""" + frame = self.env.render() + if isinstance(frame, List): + self.render_history += frame + frame = frame[-1] + + if not self.functional: + return + if self._closed: + logger.warn("The video recorder has been closed and no frames will be captured anymore.") + return + logger.debug("Capturing video frame: path=%s", self.path) + + if frame is None: + if self._async: + return + else: + # Indicates a bug in the environment: don't want to raise + # an error here. + logger.warn( + "Env returned None on `render()`. Disabling further rendering for video recorder by marking as " + f"disabled: path={self.path} metadata_path={self.metadata_path}" + ) + self.broken = True + else: + self.recorded_frames.append(frame) + + def close(self): + """Flush all data to disk and close any open frame encoders.""" + if not self.enabled or self._closed: + return + + # First close the environment + self.env.close() + + # Close the encoder + if len(self.recorded_frames) > 0: + try: + from moviepy.video.io.ImageSequenceClip import ImageSequenceClip + except ImportError as e: + raise error.DependencyNotInstalled("MoviePy is not installed, run `pip install moviepy`") from e + + logger.debug(f"Closing video encoder: path={self.path}") + clip = ImageSequenceClip(self.recorded_frames, fps=self.frames_per_sec) + clip.write_videofile(self.path) + else: + # No frames captured. Set metadata. + if self.metadata is None: + self.metadata = {} + self.metadata["empty"] = True + + self.write_metadata() + + # Stop tracking this for autoclose + self._closed = True + + def write_metadata(self): + """Writes metadata to metadata path.""" + with open(self.metadata_path, "w") as f: + json.dump(self.metadata, f) + + def __del__(self): + """Closes the environment correctly when the recorder is deleted.""" + # Make sure we've closed up shop when garbage collecting + self.close() + + class VecVideoRecorder(VecEnvWrapper): """ Wraps a VecEnv or VecEnvWrapper object to record rendered image as mp4 video. @@ -20,7 +189,7 @@ class VecVideoRecorder(VecEnvWrapper): :param name_prefix: Prefix to the video name """ - # video_recorder: video_recorder.VideoRecorder + video_recorder: VideoRecorder def __init__( self, @@ -71,9 +240,7 @@ def start_video_recorder(self) -> None: video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}" base_path = os.path.join(self.video_folder, video_name) - # self.video_recorder = video_recorder.VideoRecorder( - # env=self.env, base_path=base_path, metadata={"step_id": self.step_id} - # ) + self.video_recorder = VideoRecorder(env=self.env, base_path=base_path, metadata={"step_id": self.step_id}) self.video_recorder.capture_frame() self.recorded_frames = 1 From 2f403da3a086bd45fd4910cb69c0d5022fc26f74 Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Wed, 3 Apr 2024 22:31:19 +0100 Subject: [PATCH 10/24] Use `typing.List` rather than list --- stable_baselines3/common/vec_env/vec_video_recorder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stable_baselines3/common/vec_env/vec_video_recorder.py b/stable_baselines3/common/vec_env/vec_video_recorder.py index 76804bb41..bf2153e84 100644 --- a/stable_baselines3/common/vec_env/vec_video_recorder.py +++ b/stable_baselines3/common/vec_env/vec_video_recorder.py @@ -53,7 +53,7 @@ def __init__( self.enabled = enabled self._closed = False - self.render_history: list[np.ndarray] = [] + self.render_history: List[np.ndarray] = [] self.env = env self.render_mode = env.render_mode @@ -100,7 +100,7 @@ def __init__( self.write_metadata() logger.info(f"Starting new video recorder writing to {self.path}") - self.recorded_frames: list[np.ndarray] = [] + self.recorded_frames: List[np.ndarray] = [] @property def functional(self): From c32e1986ebc5dba29ad4ea3de7c9c910edcf70f4 Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Wed, 3 Apr 2024 23:35:05 +0100 Subject: [PATCH 11/24] Fix env attribute forwarding --- tests/test_logger.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_logger.py b/tests/test_logger.py index dfd9e5567..bea724b31 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -553,14 +553,14 @@ def test_rollout_success_rate_on_policy_algorithm(tmp_path): env = Monitor(DummySuccessEnv(dummy_successes, ep_steps), filename=monitor_file, info_keywords=("is_success",)) # Equip the model of a custom logger to check the success_rate info - model = PPO("MlpPolicy", env=env, stats_window_size=STATS_WINDOW_SIZE, n_steps=env.steps_per_log, verbose=1) + model = PPO("MlpPolicy", env=env, stats_window_size=STATS_WINDOW_SIZE, n_steps=env.env.steps_per_log, verbose=1) logger = InMemoryLogger() model.set_logger(logger) # Make the model learn and check that the success rate corresponds to the ratio of dummy successes - model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1) + model.learn(total_timesteps=env.env.ep_per_log * ep_steps, log_interval=1) assert logger.name_to_value["rollout/success_rate"] == 0.3 - model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1) + model.learn(total_timesteps=env.env.ep_per_log * ep_steps, log_interval=1) assert logger.name_to_value["rollout/success_rate"] == 0.5 - model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1) + model.learn(total_timesteps=env.env.ep_per_log * ep_steps, log_interval=1) assert logger.name_to_value["rollout/success_rate"] == 0.8 From 34637a5cb47669c2995e30442580ce51e2b82a68 Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Thu, 4 Apr 2024 09:30:05 +0100 Subject: [PATCH 12/24] Separate out env attribute collection from its utilisation --- tests/test_logger.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_logger.py b/tests/test_logger.py index bea724b31..b27f1df41 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -540,6 +540,7 @@ def test_rollout_success_rate_on_policy_algorithm(tmp_path): """ STATS_WINDOW_SIZE = 10 + # Add dummy successes with 0.3, 0.5 and 0.8 success_rate of length STATS_WINDOW_SIZE dummy_successes = [ [True] * 3 + [False] * 7, @@ -551,16 +552,17 @@ def test_rollout_success_rate_on_policy_algorithm(tmp_path): # Monitor the env to track the success info monitor_file = str(tmp_path / "monitor.csv") env = Monitor(DummySuccessEnv(dummy_successes, ep_steps), filename=monitor_file, info_keywords=("is_success",)) + steps_per_log = env.unwrapped.steps_per_log # Equip the model of a custom logger to check the success_rate info - model = PPO("MlpPolicy", env=env, stats_window_size=STATS_WINDOW_SIZE, n_steps=env.env.steps_per_log, verbose=1) + model = PPO("MlpPolicy", env=env, stats_window_size=STATS_WINDOW_SIZE, n_steps=steps_per_log, verbose=1) logger = InMemoryLogger() model.set_logger(logger) # Make the model learn and check that the success rate corresponds to the ratio of dummy successes - model.learn(total_timesteps=env.env.ep_per_log * ep_steps, log_interval=1) + model.learn(total_timesteps=steps_per_log * ep_steps, log_interval=1) assert logger.name_to_value["rollout/success_rate"] == 0.3 - model.learn(total_timesteps=env.env.ep_per_log * ep_steps, log_interval=1) + model.learn(total_timesteps=steps_per_log * ep_steps, log_interval=1) assert logger.name_to_value["rollout/success_rate"] == 0.5 - model.learn(total_timesteps=env.env.ep_per_log * ep_steps, log_interval=1) + model.learn(total_timesteps=steps_per_log * ep_steps, log_interval=1) assert logger.name_to_value["rollout/success_rate"] == 0.8 From aadb895b3d66057ae2511f03588b04b4cc035771 Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Tue, 21 May 2024 11:13:46 +0100 Subject: [PATCH 13/24] Update for Gymnasium alpha 2 --- .github/workflows/ci.yml | 8 +------- setup.py | 43 +++++++++++++++------------------------- tests/test_utils.py | 5 ++--- 3 files changed, 19 insertions(+), 37 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b1078cd28..151de07dc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,13 +34,7 @@ jobs: # cpu version of pytorch pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu - # Install Atari Roms - pip install autorom - wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64 - base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz - AutoROM --accept-license --source-file Roms.tar.gz - - pip install .[extra_no_roms,tests,docs] + pip install .[extra,tests,docs] # Use headless version pip install opencv-python-headless - name: Lint with ruff diff --git a/setup.py b/setup.py index 756454504..2b2e4f004 100644 --- a/setup.py +++ b/setup.py @@ -70,37 +70,13 @@ """ # noqa:E501 -# Atari Games download is sometimes problematic: -# https://github.com/Farama-Foundation/AutoROM/issues/39 -# That's why we define extra packages without it. -extra_no_roms = [ - # For render - "opencv-python", - "pygame", - # Tensorboard support - "tensorboard>=2.9.1", - # Checking memory taken by replay buffer - "psutil", - # For progress bar callback - "tqdm", - "rich", - # For atari games, - "shimmy[atari]~=1.3.0", - "pillow", -] - -extra_packages = extra_no_roms + [ # noqa: RUF005 - # For atari roms, - "autorom[accept-rom-license]~=0.6.1", -] - setup( name="stable_baselines3", packages=[package for package in find_packages() if package.startswith("stable_baselines3")], package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ - "gymnasium==1.0.0a1", + "gymnasium>=1.0.0a1", "numpy>=1.20", "torch>=1.13", # For saving models @@ -133,8 +109,21 @@ # Copy button for code snippets "sphinx_copybutton", ], - "extra": extra_packages, - "extra_no_roms": extra_no_roms, + "extra": [ + # For render + "opencv-python", + "pygame", + # Tensorboard support + "tensorboard>=2.9.1", + # Checking memory taken by replay buffer + "psutil", + # For progress bar callback + "tqdm", + "rich", + # For atari games, + "ale-py>=0.9.0", + "pillow", + ], }, description="Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.", author="Antonin Raffin", diff --git a/tests/test_utils.py b/tests/test_utils.py index 63eacafba..01227855b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,12 +1,12 @@ import os import shutil +import ale_py import gymnasium as gym import numpy as np import pytest import torch as th from gymnasium import spaces -from shimmy import registration import stable_baselines3 as sb3 from stable_baselines3 import A2C @@ -25,8 +25,7 @@ ) from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv -# a hack to get atari environment registered for 1.0.0 alpha 1 -registration._register_atari_envs() +gym.register_envs(ale_py) @pytest.mark.parametrize("env_id", ["CartPole-v1", lambda: gym.make("CartPole-v1")]) From 0890cd450b325296866869c1f997a4bb1939805d Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Tue, 21 May 2024 11:37:30 +0100 Subject: [PATCH 14/24] Remove assert for OrderedDict --- stable_baselines3/common/vec_env/subproc_vec_env.py | 2 +- stable_baselines3/common/vec_env/util.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index a29bd4dc9..9ebd16c8f 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -221,7 +221,7 @@ def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: spaces.Sp assert len(obs) > 0, "need observations from at least one environment" if isinstance(space, spaces.Dict): - assert isinstance(space.spaces, OrderedDict), "Dict space must have ordered subspaces" + assert isinstance(space.spaces, dict), "Dict space must have ordered subspaces" assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space" return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()]) elif isinstance(space, spaces.Tuple): diff --git a/stable_baselines3/common/vec_env/util.py b/stable_baselines3/common/vec_env/util.py index 855f50edc..2bb66a295 100644 --- a/stable_baselines3/common/vec_env/util.py +++ b/stable_baselines3/common/vec_env/util.py @@ -19,7 +19,7 @@ def copy_obs_dict(obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: :param obs: a dict of numpy arrays. :return: a dict of copied numpy arrays. """ - assert isinstance(obs, OrderedDict), f"unexpected type for observations '{type(obs)}'" + assert isinstance(obs, dict), f"unexpected type for observations '{type(obs)}'" return OrderedDict([(k, np.copy(v)) for k, v in obs.items()]) @@ -60,7 +60,7 @@ def obs_space_info(obs_space: spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[ """ check_for_nested_spaces(obs_space) if isinstance(obs_space, spaces.Dict): - assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces" + assert isinstance(obs_space.spaces, dict), "Dict space must have ordered subspaces" subspaces = obs_space.spaces elif isinstance(obs_space, spaces.Tuple): subspaces = {i: space for i, space in enumerate(obs_space.spaces)} # type: ignore[assignment] From e5b7104e3cf052982f18f528efe1b1117494e69f Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sat, 29 Jun 2024 20:19:26 +0200 Subject: [PATCH 15/24] Update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index fb5d08cf0..b89d59cba 100644 --- a/setup.py +++ b/setup.py @@ -76,7 +76,7 @@ packages=[package for package in find_packages() if package.startswith("stable_baselines3")], package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ - "gymnasium>=1.0.0a1<1.1.0", + "gymnasium>=1.0.0a1,<1.1.0", "numpy>=1.20,<2.0", # PyTorch not compatible https://github.com/pytorch/pytorch/issues/107302 "torch>=1.13", # For saving models From 868b3031d8ffb26149d74c6ac0d2748819e4869c Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Thu, 22 Aug 2024 13:15:35 +0100 Subject: [PATCH 16/24] Add type: ignore --- stable_baselines3/common/vec_env/subproc_vec_env.py | 2 +- stable_baselines3/common/vec_env/util.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index 9ebd16c8f..78087fe97 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -223,7 +223,7 @@ def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: spaces.Sp if isinstance(space, spaces.Dict): assert isinstance(space.spaces, dict), "Dict space must have ordered subspaces" assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space" - return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()]) + return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()]) # type: ignore[call-overload] elif isinstance(space, spaces.Tuple): assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space" obs_len = len(space.spaces) diff --git a/stable_baselines3/common/vec_env/util.py b/stable_baselines3/common/vec_env/util.py index 2bb66a295..9e80f2859 100644 --- a/stable_baselines3/common/vec_env/util.py +++ b/stable_baselines3/common/vec_env/util.py @@ -63,10 +63,10 @@ def obs_space_info(obs_space: spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[ assert isinstance(obs_space.spaces, dict), "Dict space must have ordered subspaces" subspaces = obs_space.spaces elif isinstance(obs_space, spaces.Tuple): - subspaces = {i: space for i, space in enumerate(obs_space.spaces)} # type: ignore[assignment] + subspaces = {i: space for i, space in enumerate(obs_space.spaces)} # type: ignore[assignment,misc] else: assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'" - subspaces = {None: obs_space} # type: ignore[assignment] + subspaces = {None: obs_space} # type: ignore[assignment,dict-item] keys = [] shapes = {} dtypes = {} @@ -74,4 +74,4 @@ def obs_space_info(obs_space: spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[ keys.append(key) shapes[key] = box.shape dtypes[key] = box.dtype - return keys, shapes, dtypes + return keys, shapes, dtypes # type: ignore[return-value] From 5c0fca61de072c7830d567c6e2669bd94c81d5a1 Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Thu, 22 Aug 2024 13:15:46 +0100 Subject: [PATCH 17/24] Test with Gymnasium main --- .github/workflows/ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f394eb137..e46062014 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -37,6 +37,7 @@ jobs: pip install .[extra,tests,docs] # Use headless version pip install opencv-python-headless + pip install git+https://github.com/farama-Foundation/gymnasium - name: Lint with ruff run: | make lint From 4a44f5068a422a3b2b602d8335d773e3d2c131a3 Mon Sep 17 00:00:00 2001 From: pseudo-rnd-thoughts Date: Thu, 22 Aug 2024 13:25:10 +0100 Subject: [PATCH 18/24] Remove `gymnasium.logger.debug/info` --- stable_baselines3/common/vec_env/vec_video_recorder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/stable_baselines3/common/vec_env/vec_video_recorder.py b/stable_baselines3/common/vec_env/vec_video_recorder.py index bf2153e84..0ca11a7b2 100644 --- a/stable_baselines3/common/vec_env/vec_video_recorder.py +++ b/stable_baselines3/common/vec_env/vec_video_recorder.py @@ -99,7 +99,7 @@ def __init__( self.metadata_path = f"{path_base}.meta.json" self.write_metadata() - logger.info(f"Starting new video recorder writing to {self.path}") + # logger.debug(f"Starting new video recorder writing to {self.path}") self.recorded_frames: List[np.ndarray] = [] @property @@ -119,7 +119,7 @@ def capture_frame(self): if self._closed: logger.warn("The video recorder has been closed and no frames will be captured anymore.") return - logger.debug("Capturing video frame: path=%s", self.path) + # logger.debug(f"Capturing video frame: path={self.path}") if frame is None: if self._async: @@ -150,7 +150,7 @@ def close(self): except ImportError as e: raise error.DependencyNotInstalled("MoviePy is not installed, run `pip install moviepy`") from e - logger.debug(f"Closing video encoder: path={self.path}") + # logger.debug(f"Closing video encoder: path={self.path}") clip = ImageSequenceClip(self.recorded_frames, fps=self.frames_per_sec) clip.write_videofile(self.path) else: From 3b48d27805696c09bc49a589d8274c551762be55 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sat, 2 Nov 2024 08:41:00 +0100 Subject: [PATCH 19/24] Fix github CI yaml --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index be431bd3c..80377427f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -47,7 +47,7 @@ jobs: uv pip install --system .[extra,tests,docs] # Use headless version uv pip install --system opencv-python-headless - - name: Install specific version of gym + - name: Install specific version of gym run: | uv pip install --system gymnasium==${{ matrix.gymnasium-version }} # Only run for python 3.10, downgrade gym to 0.29.1 From 0f97c3b8d5a6bcf534babb806397a52091470507 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sat, 2 Nov 2024 08:43:07 +0100 Subject: [PATCH 20/24] Run gym 0.29.1 on python 3.10 --- .github/workflows/ci.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 80377427f..d34a93c9a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -27,8 +27,6 @@ jobs: # Add a new config to test gym<1.0 - python-version: "3.10" gymnasium-version: "0.29.1" - - python-version: "3.10" - gymnasium-version: "1.0.0" steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} From 1b10cef8914b57fef3f84319f9fc97ffdb617646 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sat, 2 Nov 2024 10:08:20 +0100 Subject: [PATCH 21/24] Update lower bounds --- docs/conda_env.yml | 2 +- docs/misc/changelog.rst | 2 +- setup.py | 2 +- stable_baselines3/common/vec_env/patch_gym.py | 2 +- stable_baselines3/version.txt | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/conda_env.yml b/docs/conda_env.yml index e025a57e1..ac065b3b9 100644 --- a/docs/conda_env.yml +++ b/docs/conda_env.yml @@ -8,7 +8,7 @@ dependencies: - python=3.11 - pytorch=2.5.0=py3.11_cpu_0 - pip: - - gymnasium>=0.28.1,<0.30 + - gymnasium>=0.29.1,<1.1.0 - cloudpickle - opencv-python-headless - pandas diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index b32cd7ce1..3587cc670 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.4.0a10 (WIP) +Release 2.4.0a11 (WIP) -------------------------- **New algorithm: CrossQ in SB3 Contrib** diff --git a/setup.py b/setup.py index 4a7561a48..52f626462 100644 --- a/setup.py +++ b/setup.py @@ -76,7 +76,7 @@ packages=[package for package in find_packages() if package.startswith("stable_baselines3")], package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ - "gymnasium>=0.28.1,<1.1.0", + "gymnasium>=0.29.1,<1.1.0", "numpy>=1.20,<2.0", # PyTorch not compatible https://github.com/pytorch/pytorch/issues/107302 "torch>=1.13", # For saving models diff --git a/stable_baselines3/common/vec_env/patch_gym.py b/stable_baselines3/common/vec_env/patch_gym.py index 6ba655ebf..874809a03 100644 --- a/stable_baselines3/common/vec_env/patch_gym.py +++ b/stable_baselines3/common/vec_env/patch_gym.py @@ -43,7 +43,7 @@ def _patch_env(env: Union["gym.Env", gymnasium.Env]) -> gymnasium.Env: # pragma "Missing shimmy installation. You provided an OpenAI Gym environment. " "Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. " "In order to use OpenAI Gym environments with SB3, you need to " - "install shimmy (`pip install 'shimmy>=0.2.1'`)." + "install shimmy (`pip install 'shimmy>=2.0'`)." ) from e warnings.warn( diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 852a32b3f..d5cafdb5a 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.4.0a10 +2.4.0a11 From 45cd5f8eb628b14d433d2142385d6dcc71a3653d Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sat, 2 Nov 2024 11:08:57 +0100 Subject: [PATCH 22/24] Integrate video recorder --- .../common/vec_env/vec_video_recorder.py | 268 +++++------------- 1 file changed, 72 insertions(+), 196 deletions(-) diff --git a/stable_baselines3/common/vec_env/vec_video_recorder.py b/stable_baselines3/common/vec_env/vec_video_recorder.py index 0ca11a7b2..e586f94ab 100644 --- a/stable_baselines3/common/vec_env/vec_video_recorder.py +++ b/stable_baselines3/common/vec_env/vec_video_recorder.py @@ -1,8 +1,6 @@ -import json import os import os.path -import tempfile -from typing import Callable, List, Optional +from typing import Callable, List import numpy as np from gymnasium import error, logger @@ -12,174 +10,16 @@ from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv -# This is copy and pasted from Gymnasium v0.26.1 -class VideoRecorder: - """VideoRecorder renders a nice movie of a rollout, frame by frame. - - It comes with an ``enabled`` option, so you can still use the same code on episodes where you don't want to record video. - - Note: - You are responsible for calling :meth:`close` on a created VideoRecorder, or else you may leak an encoder process. - """ - - def __init__( - self, - env, - path: Optional[str] = None, - metadata: Optional[dict] = None, - enabled: bool = True, - base_path: Optional[str] = None, - ): - """Video recorder renders a nice movie of a rollout, frame by frame. - - Args: - env (Env): Environment to take video of. - path (Optional[str]): Path to the video file; will be randomly chosen if omitted. - metadata (Optional[dict]): Contents to save to the metadata file. - enabled (bool): Whether to actually record video, or just no-op (for convenience) - base_path (Optional[str]): Alternatively, path to the video file without extension, which will be added. - - Raises: - Error: You can pass at most one of `path` or `base_path` - Error: Invalid path given that must have a particular file extension - """ - try: - # check that moviepy is now installed - import moviepy # noqa: F401 - except ImportError as e: - raise error.DependencyNotInstalled("MoviePy is not installed, run `pip install moviepy`") from e - - self._async = env.metadata.get("semantics.async") - self.enabled = enabled - self._closed = False - - self.render_history: List[np.ndarray] = [] - self.env = env - - self.render_mode = env.render_mode - - if "rgb_array_list" != self.render_mode and "rgb_array" != self.render_mode: - logger.warn( - f"Disabling video recorder because environment {env} was not initialized with any compatible video " - "mode between `rgb_array` and `rgb_array_list`" - ) - # Disable since the environment has not been initialized with a compatible `render_mode` - self.enabled = False - - # Don't bother setting anything else if not enabled - if not self.enabled: - return - - if path is not None and base_path is not None: - raise error.Error("You can pass at most one of `path` or `base_path`.") - - required_ext = ".mp4" - if path is None: - if base_path is not None: - # Base path given, append ext - path = base_path + required_ext - else: - # Otherwise, just generate a unique filename - with tempfile.NamedTemporaryFile(suffix=required_ext) as f: - path = f.name - self.path = path - - path_base, actual_ext = os.path.splitext(self.path) - - if actual_ext != required_ext: - raise error.Error(f"Invalid path given: {self.path} -- must have file extension {required_ext}.") - - self.frames_per_sec = env.metadata.get("render_fps", 30) - - self.broken = False - - # Dump metadata - self.metadata = metadata or {} - self.metadata["content_type"] = "video/mp4" - self.metadata_path = f"{path_base}.meta.json" - self.write_metadata() - - # logger.debug(f"Starting new video recorder writing to {self.path}") - self.recorded_frames: List[np.ndarray] = [] - - @property - def functional(self): - """Returns if the video recorder is functional, is enabled and not broken.""" - return self.enabled and not self.broken - - def capture_frame(self): - """Render the given `env` and add the resulting frame to the video.""" - frame = self.env.render() - if isinstance(frame, List): - self.render_history += frame - frame = frame[-1] - - if not self.functional: - return - if self._closed: - logger.warn("The video recorder has been closed and no frames will be captured anymore.") - return - # logger.debug(f"Capturing video frame: path={self.path}") - - if frame is None: - if self._async: - return - else: - # Indicates a bug in the environment: don't want to raise - # an error here. - logger.warn( - "Env returned None on `render()`. Disabling further rendering for video recorder by marking as " - f"disabled: path={self.path} metadata_path={self.metadata_path}" - ) - self.broken = True - else: - self.recorded_frames.append(frame) - - def close(self): - """Flush all data to disk and close any open frame encoders.""" - if not self.enabled or self._closed: - return - - # First close the environment - self.env.close() - - # Close the encoder - if len(self.recorded_frames) > 0: - try: - from moviepy.video.io.ImageSequenceClip import ImageSequenceClip - except ImportError as e: - raise error.DependencyNotInstalled("MoviePy is not installed, run `pip install moviepy`") from e - - # logger.debug(f"Closing video encoder: path={self.path}") - clip = ImageSequenceClip(self.recorded_frames, fps=self.frames_per_sec) - clip.write_videofile(self.path) - else: - # No frames captured. Set metadata. - if self.metadata is None: - self.metadata = {} - self.metadata["empty"] = True - - self.write_metadata() - - # Stop tracking this for autoclose - self._closed = True - - def write_metadata(self): - """Writes metadata to metadata path.""" - with open(self.metadata_path, "w") as f: - json.dump(self.metadata, f) - - def __del__(self): - """Closes the environment correctly when the recorder is deleted.""" - # Make sure we've closed up shop when garbage collecting - self.close() - - class VecVideoRecorder(VecEnvWrapper): """ Wraps a VecEnv or VecEnvWrapper object to record rendered image as mp4 video. It requires ffmpeg or avconv to be installed on the machine. + Note: for now it only allows to record one video and all videos + must have at least two frames. + + The video recorder code was adapted from Gymnasium v1.0. + :param venv: :param video_folder: Where to save videos :param record_video_trigger: Function that defines when to start recording. @@ -189,8 +29,6 @@ class VecVideoRecorder(VecEnvWrapper): :param name_prefix: Prefix to the video name """ - video_recorder: VideoRecorder - def __init__( self, venv: VecEnv, @@ -218,6 +56,8 @@ def __init__( self.env.metadata = metadata assert self.env.render_mode == "rgb_array", f"The render_mode must be 'rgb_array', not {self.env.render_mode}" + self.frames_per_sec = self.env.metadata.get("render_fps", 30) + self.record_video_trigger = record_video_trigger self.video_folder = os.path.abspath(video_folder) # Create output folder if needed @@ -227,52 +67,88 @@ def __init__( self.step_id = 0 self.video_length = video_length + self.video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}.mp4" + self.video_path = os.path.join(self.video_folder, self.video_name) + self.recording = False - self.recorded_frames = 0 + self.recorded_frames: list[np.ndarray] = [] + + try: + import moviepy # noqa: F401 + except ImportError as e: + raise error.DependencyNotInstalled("MoviePy is not installed, run `pip install 'gymnasium[other]'`") from e def reset(self) -> VecEnvObs: obs = self.venv.reset() - self.start_video_recorder() + if self._video_enabled(): + self._start_video_recorder() return obs - def start_video_recorder(self) -> None: - self.close_video_recorder() - - video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}" - base_path = os.path.join(self.video_folder, video_name) - self.video_recorder = VideoRecorder(env=self.env, base_path=base_path, metadata={"step_id": self.step_id}) - - self.video_recorder.capture_frame() - self.recorded_frames = 1 - self.recording = True + def _start_video_recorder(self) -> None: + self._start_recording() + self._capture_frame() def _video_enabled(self) -> bool: return self.record_video_trigger(self.step_id) def step_wait(self) -> VecEnvStepReturn: - obs, rews, dones, infos = self.venv.step_wait() + obs, rewards, dones, infos = self.venv.step_wait() self.step_id += 1 if self.recording: - self.video_recorder.capture_frame() - self.recorded_frames += 1 - if self.recorded_frames > self.video_length: - print(f"Saving video to {self.video_recorder.path}") - self.close_video_recorder() + self._capture_frame() + if len(self.recorded_frames) > self.video_length: + print(f"Saving video to {self.video_path}") + self._stop_recording() elif self._video_enabled(): - self.start_video_recorder() + self._start_video_recorder() - return obs, rews, dones, infos + return obs, rewards, dones, infos - def close_video_recorder(self) -> None: - if self.recording: - self.video_recorder.close() - self.recording = False - self.recorded_frames = 1 + def _capture_frame(self) -> None: + assert self.recording, "Cannot capture a frame, recording wasn't started." + + frame = self.env.render() + if isinstance(frame, List): + frame = frame[-1] + + if isinstance(frame, np.ndarray): + self.recorded_frames.append(frame) + else: + self._stop_recording() + logger.warn( + f"Recording stopped: expected type of frame returned by render to be a numpy array, got instead {type(frame)}." + ) def close(self) -> None: + """Closes the wrapper then the video recorder.""" VecEnvWrapper.close(self) - self.close_video_recorder() + if self.recording: + self._stop_recording() + + def _start_recording(self) -> None: + """Start a new recording. If it is already recording, stops the current recording before starting the new one.""" + if self.recording: + self._stop_recording() - def __del__(self): - self.close_video_recorder() + self.recording = True + + def _stop_recording(self) -> None: + """Stop current recording and saves the video.""" + assert self.recording, "_stop_recording was called, but no recording was started" + + if len(self.recorded_frames) == 0: + logger.warn("Ignored saving a video as there were zero frames to save.") + else: + from moviepy.video.io.ImageSequenceClip import ImageSequenceClip + + clip = ImageSequenceClip(self.recorded_frames, fps=self.frames_per_sec) + clip.write_videofile(self.video_path) + + self.recorded_frames = [] + self.recording = False + + def __del__(self) -> None: + """Warn the user in case last video wasn't saved.""" + if len(self.recorded_frames) > 0: + logger.warn("Unable to save last video! Did you call close()?") From cba9a2cdbe70ea65a1baa0ad6b0f9a4ea6083a58 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sun, 3 Nov 2024 18:31:21 +0100 Subject: [PATCH 23/24] Remove ordered dict --- .../common/vec_env/dummy_vec_env.py | 4 +-- .../common/vec_env/subproc_vec_env.py | 28 +++++++++---------- stable_baselines3/common/vec_env/util.py | 12 -------- tests/test_vec_envs.py | 2 +- 4 files changed, 17 insertions(+), 29 deletions(-) diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index a37d3f254..5625e2453 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -8,7 +8,7 @@ from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn from stable_baselines3.common.vec_env.patch_gym import _patch_env -from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info +from stable_baselines3.common.vec_env.util import dict_to_obs, obs_space_info class DummyVecEnv(VecEnv): @@ -110,7 +110,7 @@ def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None: self.buf_obs[key][env_idx] = obs[key] # type: ignore[call-overload] def _obs_from_buf(self) -> VecEnvObs: - return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs)) + return dict_to_obs(self.observation_space, deepcopy(self.buf_obs)) def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: """Return attribute from vectorized environment (see base class).""" diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index 78087fe97..a606a7cb9 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -1,6 +1,5 @@ import multiprocessing as mp import warnings -from collections import OrderedDict from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union import gymnasium as gym @@ -129,7 +128,7 @@ def step_wait(self) -> VecEnvStepReturn: results = [remote.recv() for remote in self.remotes] self.waiting = False obs, rews, dones, infos, self.reset_infos = zip(*results) # type: ignore[assignment] - return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos # type: ignore[return-value] + return _stack_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos # type: ignore[return-value] def reset(self) -> VecEnvObs: for env_idx, remote in enumerate(self.remotes): @@ -139,7 +138,7 @@ def reset(self) -> VecEnvObs: # Seeds and options are only used once self._reset_seeds() self._reset_options() - return _flatten_obs(obs, self.observation_space) + return _stack_obs(obs, self.observation_space) def close(self) -> None: if self.closed: @@ -206,27 +205,28 @@ def _get_target_remotes(self, indices: VecEnvIndices) -> List[Any]: return [self.remotes[i] for i in indices] -def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: spaces.Space) -> VecEnvObs: +def _stack_obs(obs_list: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: spaces.Space) -> VecEnvObs: """ - Flatten observations, depending on the observation space. + Stack observations (convert from a list of single env obs to a stack of obs), + depending on the observation space. :param obs: observations. A list or tuple of observations, one per environment. Each environment observation may be a NumPy array, or a dict or tuple of NumPy arrays. - :return: flattened observations. - A flattened NumPy array or an OrderedDict or tuple of flattened numpy arrays. + :return: Concatenated observations. + A NumPy array or a dict or tuple of stacked numpy arrays. Each NumPy array has the environment index as its first axis. """ - assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment" - assert len(obs) > 0, "need observations from at least one environment" + assert isinstance(obs_list, (list, tuple)), "expected list or tuple of observations per environment" + assert len(obs_list) > 0, "need observations from at least one environment" if isinstance(space, spaces.Dict): assert isinstance(space.spaces, dict), "Dict space must have ordered subspaces" - assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space" - return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()]) # type: ignore[call-overload] + assert isinstance(obs_list[0], dict), "non-dict observation for environment with Dict observation space" + return {key: np.stack([single_obs[key] for single_obs in obs_list]) for key in space.spaces.keys()} # type: ignore[call-overload] elif isinstance(space, spaces.Tuple): - assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space" + assert isinstance(obs_list[0], tuple), "non-tuple observation for environment with Tuple observation space" obs_len = len(space.spaces) - return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len)) # type: ignore[index] + return tuple(np.stack([single_obs[i] for single_obs in obs_list]) for i in range(obs_len)) # type: ignore[index] else: - return np.stack(obs) # type: ignore[arg-type] + return np.stack(obs_list) # type: ignore[arg-type] diff --git a/stable_baselines3/common/vec_env/util.py b/stable_baselines3/common/vec_env/util.py index 9e80f2859..6ea04f6ab 100644 --- a/stable_baselines3/common/vec_env/util.py +++ b/stable_baselines3/common/vec_env/util.py @@ -2,7 +2,6 @@ Helpers for dealing with vectorized environments. """ -from collections import OrderedDict from typing import Any, Dict, List, Tuple import numpy as np @@ -12,17 +11,6 @@ from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs -def copy_obs_dict(obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: - """ - Deep-copy a dict of numpy arrays. - - :param obs: a dict of numpy arrays. - :return: a dict of copied numpy arrays. - """ - assert isinstance(obs, dict), f"unexpected type for observations '{type(obs)}'" - return OrderedDict([(k, np.copy(v)) for k, v in obs.items()]) - - def dict_to_obs(obs_space: spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs: """ Convert an internal representation raw_obs into the appropriate type diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index a9516ae25..3aa52762d 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -307,7 +307,7 @@ def test_vecenv_dict_spaces(vec_env_class): space = spaces.Dict(SPACES) def obs_assert(obs): - assert isinstance(obs, collections.OrderedDict) + assert isinstance(obs, dict) assert obs.keys() == space.spaces.keys() for key, values in obs.items(): check_vecenv_obs(values, space.spaces[key]) From df5fdaa5ac0a3c7771957c4479db3ea14eed45b2 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sun, 3 Nov 2024 18:36:10 +0100 Subject: [PATCH 24/24] Update changelog --- docs/misc/changelog.rst | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 3587cc670..cf2a2a520 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -6,7 +6,7 @@ Changelog Release 2.4.0a11 (WIP) -------------------------- -**New algorithm: CrossQ in SB3 Contrib** +**New algorithm: CrossQ in SB3 Contrib, Gymnasium v1.0 support** .. note:: @@ -24,12 +24,14 @@ Release 2.4.0a11 (WIP) Breaking Changes: ^^^^^^^^^^^^^^^^^ +- Increase minimum required version of Gymnasium to 0.29.1 New Features: ^^^^^^^^^^^^^ - Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ) - Enabled np.ndarray logging for TensorBoardOutputFormat as histogram (see GH#1634) (@iwishwasaneagle) - Updated env checker to warn users when using multi-dim array to define `MultiDiscrete` spaces +- Added support for Gymnasium v1.0 Bug Fixes: ^^^^^^^^^^ @@ -69,6 +71,7 @@ Others: - Added a warning to recommend using CPU with on policy algorithms (A2C/PPO) and ``MlpPolicy`` - Switched to uv to download packages faster on GitHub CI - Updated dependencies for read the doc +- Removed unnecessary ``copy_obs_dict`` method for ``SubprocVecEnv``, remove the use of ordered dict and rename ``flatten_obs`` to ``stack_obs`` Bug Fixes: ^^^^^^^^^^