From 000544cc1fe6a1c1ec80c125dadad11ad49e1473 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 22 Jul 2024 13:42:33 +0200 Subject: [PATCH 01/11] Add support for pre and post linear modules in `create_mlp` (#1975) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add support for pre and post linear modules in `create_mlp` * Disable mypy for python 3.8 * Reformat toml file * Update docstring Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> * Add some comments --------- Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- .github/workflows/ci.yml | 72 ++++++++++++------------ docs/misc/changelog.rst | 3 +- pyproject.toml | 32 ++++++----- stable_baselines3/common/torch_layers.py | 53 ++++++++++++++--- stable_baselines3/version.txt | 2 +- tests/test_custom_policy.py | 56 ++++++++++++++++++ 6 files changed, 157 insertions(+), 61 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0efc16e56..822e0cb3f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -5,9 +5,9 @@ name: CI on: push: - branches: [ master ] + branches: [master] pull_request: - branches: [ master ] + branches: [master] jobs: build: @@ -23,38 +23,40 @@ jobs: python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - # cpu version of pytorch - pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + # cpu version of pytorch + pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu - # Install Atari Roms - pip install autorom - wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64 - base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz - AutoROM --accept-license --source-file Roms.tar.gz + # Install Atari Roms + pip install autorom + wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64 + base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz + AutoROM --accept-license --source-file Roms.tar.gz - pip install .[extra_no_roms,tests,docs] - # Use headless version - pip install opencv-python-headless - - name: Lint with ruff - run: | - make lint - - name: Build the doc - run: | - make doc - - name: Check codestyle - run: | - make check-codestyle - - name: Type check - run: | - make type - - name: Test with pytest - run: | - make pytest + pip install .[extra_no_roms,tests,docs] + # Use headless version + pip install opencv-python-headless + - name: Lint with ruff + run: | + make lint + - name: Build the doc + run: | + make doc + - name: Check codestyle + run: | + make check-codestyle + - name: Type check + run: | + make type + # Do not run for python 3.8 (mypy internal error) + if: matrix.python-version != '3.8' + - name: Test with pytest + run: | + make pytest diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 78eb2bd0e..31ff99d09 100644 --- a/docs/misc/changelog.rst +++ 
b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.4.0a5 (WIP) +Release 2.4.0a6 (WIP) -------------------------- Breaking Changes: @@ -11,6 +11,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ +- Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ) Bug Fixes: ^^^^^^^^^^ diff --git a/pyproject.toml b/pyproject.toml index 8e20ffe00..dd435a33e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,10 +13,10 @@ ignore = ["B028", "RUF013"] [tool.ruff.lint.per-file-ignores] # Default implementation in abstract methods -"./stable_baselines3/common/callbacks.py"= ["B027"] -"./stable_baselines3/common/noise.py"= ["B027"] +"./stable_baselines3/common/callbacks.py" = ["B027"] +"./stable_baselines3/common/noise.py" = ["B027"] # ClassVar, implicit optional check not needed for tests -"./tests/*.py"= ["RUF012", "RUF013"] +"./tests/*.py" = ["RUF012", "RUF013"] [tool.ruff.lint.mccabe] @@ -37,9 +37,7 @@ exclude = """(?x)( [tool.pytest.ini_options] # Deterministic ordering for tests; useful for pytest-xdist. -env = [ - "PYTHONHASHSEED=0" -] +env = ["PYTHONHASHSEED=0"] filterwarnings = [ # Tensorboard warnings @@ -47,23 +45,27 @@ filterwarnings = [ # Gymnasium warnings "ignore::UserWarning:gymnasium", # tqdm warning about rich being experimental - "ignore:rich is experimental" + "ignore:rich is experimental", ] markers = [ - "expensive: marks tests as expensive (deselect with '-m \"not expensive\"')" + "expensive: marks tests as expensive (deselect with '-m \"not expensive\"')", ] [tool.coverage.run] disable_warnings = ["couldnt-parse"] branch = false omit = [ - "tests/*", - "setup.py", - # Require graphical interface - "stable_baselines3/common/results_plotter.py", - # Require ffmpeg - "stable_baselines3/common/vec_env/vec_video_recorder.py", + "tests/*", + "setup.py", + # Require graphical interface + "stable_baselines3/common/results_plotter.py", + # Require ffmpeg + "stable_baselines3/common/vec_env/vec_video_recorder.py", ] [tool.coverage.report] -exclude_lines = [ "pragma: no cover", "raise NotImplementedError()", "if typing.TYPE_CHECKING:"] +exclude_lines = [ + "pragma: no cover", + "raise NotImplementedError()", + "if typing.TYPE_CHECKING:", +] diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index bb3ba5de8..234b91551 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Tuple, Type, Union +from typing import Dict, List, Optional, Tuple, Type, Union import gymnasium as gym import torch as th @@ -14,7 +14,7 @@ class BaseFeaturesExtractor(nn.Module): """ Base class that represents a features extractor. - :param observation_space: + :param observation_space: The observation space of the environment :param features_dim: Number of features extracted. """ @@ -26,6 +26,7 @@ def __init__(self, observation_space: gym.Space, features_dim: int = 0) -> None: @property def features_dim(self) -> int: + """The number of features that the extractor outputs.""" return self._features_dim @@ -34,7 +35,7 @@ class FlattenExtractor(BaseFeaturesExtractor): Feature extract that flatten the input. Used as a placeholder when feature extraction is not needed. 
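For context: ``FlattenExtractor`` simply flattens each observation and reports the flattened size through the ``features_dim`` property documented above. A minimal sketch using the standard SB3 API:

```python
import gymnasium as gym
import torch as th

from stable_baselines3.common.torch_layers import FlattenExtractor

space = gym.spaces.Box(low=-1.0, high=1.0, shape=(3, 4))
extractor = FlattenExtractor(space)
assert extractor.features_dim == 12  # 3 * 4 after flattening
features = extractor(th.randn(2, 3, 4))  # batch of 2 -> shape (2, 12)
```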
- :param observation_space: + :param observation_space: The observation space of the environment """ def __init__(self, observation_space: gym.Space) -> None: @@ -52,7 +53,7 @@ class NatureCNN(BaseFeaturesExtractor): "Human-level control through deep reinforcement learning." Nature 518.7540 (2015): 529-533. - :param observation_space: + :param observation_space: The observation space of the environment :param features_dim: Number of features extracted. This corresponds to the number of unit for the last layer. :param normalized_image: Whether to assume that the image is already normalized @@ -113,13 +114,15 @@ def create_mlp( activation_fn: Type[nn.Module] = nn.ReLU, squash_output: bool = False, with_bias: bool = True, + pre_linear_modules: Optional[List[Type[nn.Module]]] = None, + post_linear_modules: Optional[List[Type[nn.Module]]] = None, ) -> List[nn.Module]: """ Create a multi layer perceptron (MLP), which is a collection of fully-connected layers each followed by an activation function. :param input_dim: Dimension of the input vector - :param output_dim: + :param output_dim: Dimension of the output (last layer, for instance, the number of actions) :param net_arch: Architecture of the neural net It represents the number of units per layer. The length of this list is the number of layers. @@ -128,20 +131,52 @@ def create_mlp( :param squash_output: Whether to squash the output using a Tanh activation function :param with_bias: If set to False, the layers will not learn an additive bias - :return: + :param pre_linear_modules: List of nn.Module to add before the linear layers. + These modules should maintain the input tensor dimension (e.g. BatchNorm). + The number of input features is passed to the module's constructor. + Compared to post_linear_modules, they are used before the output layer (output_dim > 0). + :param post_linear_modules: List of nn.Module to add after the linear layers + (and before the activation function). These modules should maintain the input + tensor dimension (e.g. Dropout, LayerNorm). They are not used after the + output layer (output_dim > 0). The number of input features is passed to + the module's constructor. 
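To make the new arguments concrete, here is a minimal usage sketch (it mirrors the layer ordering asserted in the tests further down in this patch):

```python
import torch as th
import torch.nn as nn

from stable_baselines3.common.torch_layers import create_mlp

# BatchNorm1d is inserted before every linear layer (including the output
# layer); LayerNorm is inserted after every hidden linear layer, but not
# after the output layer.
layers = create_mlp(
    input_dim=6,
    output_dim=2,
    net_arch=[8, 12],
    pre_linear_modules=[nn.BatchNorm1d],
    post_linear_modules=[nn.LayerNorm],
)
mlp = nn.Sequential(*layers)
out = mlp(th.randn(4, 6))  # batch of 4 -> output shape (4, 2)
```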
+ :return: The list of layers of the neural network """ + pre_linear_modules = pre_linear_modules or [] + post_linear_modules = post_linear_modules or [] + + modules = [] if len(net_arch) > 0: - modules = [nn.Linear(input_dim, net_arch[0], bias=with_bias), activation_fn()] - else: - modules = [] + # BatchNorm maintains input dim + for module in pre_linear_modules: + modules.append(module(input_dim)) + + modules.append(nn.Linear(input_dim, net_arch[0], bias=with_bias)) + + # LayerNorm, Dropout maintain output dim + for module in post_linear_modules: + modules.append(module(net_arch[0])) + + modules.append(activation_fn()) for idx in range(len(net_arch) - 1): + for module in pre_linear_modules: + modules.append(module(net_arch[idx])) + modules.append(nn.Linear(net_arch[idx], net_arch[idx + 1], bias=with_bias)) + + for module in post_linear_modules: + modules.append(module(net_arch[idx + 1])) + modules.append(activation_fn()) if output_dim > 0: last_layer_dim = net_arch[-1] if len(net_arch) > 0 else input_dim + # Only add BatchNorm before output layer + for module in pre_linear_modules: + modules.append(module(last_layer_dim)) + modules.append(nn.Linear(last_layer_dim, output_dim, bias=with_bias)) if squash_output: modules.append(nn.Tanh()) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index a1fd35b5f..464a5c4dc 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.4.0a5 +2.4.0a6 diff --git a/tests/test_custom_policy.py b/tests/test_custom_policy.py index 1f89b23d6..e92ffe8b7 100644 --- a/tests/test_custom_policy.py +++ b/tests/test_custom_policy.py @@ -1,8 +1,10 @@ import pytest import torch as th +import torch.nn as nn from stable_baselines3 import A2C, DQN, PPO, SAC, TD3 from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike +from stable_baselines3.common.torch_layers import create_mlp @pytest.mark.parametrize( @@ -62,3 +64,57 @@ def test_tf_like_rmsprop_optimizer(): def test_dqn_custom_policy(): policy_kwargs = dict(optimizer_class=RMSpropTFLike, net_arch=[32]) _ = DQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, learning_starts=100).learn(300) + + +def test_create_mlp(): + net = create_mlp(4, 2, net_arch=[16, 8], squash_output=True) + # We cannot compare the network directly because the modules have different ids + # assert net == [nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 2), + # nn.Tanh()] + assert len(net) == 6 + assert isinstance(net[0], nn.Linear) + assert net[0].in_features == 4 + assert net[0].out_features == 16 + assert isinstance(net[1], nn.ReLU) + assert isinstance(net[2], nn.Linear) + assert isinstance(net[4], nn.Linear) + assert net[4].in_features == 8 + assert net[4].out_features == 2 + assert isinstance(net[5], nn.Tanh) + + # Linear network + net = create_mlp(4, -1, net_arch=[]) + assert net == [] + + # No output layer, with custom activation function + net = create_mlp(6, -1, net_arch=[8], activation_fn=nn.Tanh) + # assert net == [nn.Linear(6, 8), nn.Tanh()] + assert len(net) == 2 + assert isinstance(net[0], nn.Linear) + assert net[0].in_features == 6 + assert net[0].out_features == 8 + assert isinstance(net[1], nn.Tanh) + + # Using pre-linear and post-linear modules + pre_linear = [nn.BatchNorm1d] + post_linear = [nn.LayerNorm] + net = create_mlp(6, 2, net_arch=[8, 12], pre_linear_modules=pre_linear, post_linear_modules=post_linear) + # assert net == [nn.BatchNorm1d(6), nn.Linear(6, 8), nn.LayerNorm(8), nn.ReLU() + # nn.BatchNorm1d(6), 
nn.Linear(8, 12), nn.LayerNorm(12), nn.ReLU(), + # nn.BatchNorm1d(12), nn.Linear(12, 2)] # Last layer does not have post_linear + assert len(net) == 10 + assert isinstance(net[0], nn.BatchNorm1d) + assert net[0].num_features == 6 + assert isinstance(net[1], nn.Linear) + assert isinstance(net[2], nn.LayerNorm) + assert isinstance(net[3], nn.ReLU) + assert isinstance(net[4], nn.BatchNorm1d) + assert isinstance(net[5], nn.Linear) + assert net[5].in_features == 8 + assert net[5].out_features == 12 + assert isinstance(net[6], nn.LayerNorm) + assert isinstance(net[7], nn.ReLU) + assert isinstance(net[8], nn.BatchNorm1d) + assert isinstance(net[-1], nn.Linear) + assert net[-1].in_features == 12 + assert net[-1].out_features == 2 From bd3c0c653068a6af1993df7be1a12acfb4be0127 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 26 Jul 2024 14:57:55 +0200 Subject: [PATCH 02/11] Fix loading of optimizer with older DQN models (#1978) --- docs/misc/changelog.rst | 11 ++++++++++- stable_baselines3/common/base_class.py | 27 ++++++++++++++++++++++++-- stable_baselines3/version.txt | 2 +- tests/test_save_load.py | 14 ++++++++++++- 4 files changed, 49 insertions(+), 5 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 31ff99d09..37a035478 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,9 +3,16 @@ Changelog ========== -Release 2.4.0a6 (WIP) +Release 2.4.0a7 (WIP) -------------------------- +.. note:: + + DQN (and QR-DQN) models saved with SB3 < 2.4.0 will show a warning about + truncation of optimizer state when loaded with SB3 >= 2.4.0. + To suppress the warning, simply save the model again. + You can find more info in `PR #1963 `_ + Breaking Changes: ^^^^^^^^^^^^^^^^^ @@ -28,9 +35,11 @@ Bug Fixes: `RL Zoo`_ ^^^^^^^^^ +- Updated defaults hyperparameters for TQC/SAC for Swimmer-v4 (decrease gamma for more consistent results) `SBX`_ (SB3 + Jax) ^^^^^^^^^^^^^^^^^^ +- Added CNN support for DQN Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index b2c967405..e43955f94 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -742,13 +742,13 @@ def load( # noqa: C901 # put state_dicts back in place model.set_parameters(params, exact_match=True, device=device) except RuntimeError as e: - # Patch to load Policy saved using SB3 < 1.7.0 + # Patch to load policies saved using SB3 < 1.7.0 # the error is probably due to old policy being loaded # See https://github.com/DLR-RM/stable-baselines3/issues/1233 if "pi_features_extractor" in str(e) and "Missing key(s) in state_dict" in str(e): model.set_parameters(params, exact_match=False, device=device) warnings.warn( - "You are probably loading a model saved with SB3 < 1.7.0, " + "You are probably loading a A2C/PPO model saved with SB3 < 1.7.0, " "we deactivated exact_match so you can save the model " "again to avoid issues in the future " "(see https://github.com/DLR-RM/stable-baselines3/issues/1233 for more info). 
" @@ -757,6 +757,29 @@ def load( # noqa: C901 ) else: raise e + except ValueError as e: + # Patch to load DQN policies saved using SB3 < 2.4.0 + # The target network params are no longer in the optimizer + # See https://github.com/DLR-RM/stable-baselines3/pull/1963 + saved_optim_params = params["policy.optimizer"]["param_groups"][0]["params"] # type: ignore[index] + n_params_saved = len(saved_optim_params) + n_params = len(model.policy.optimizer.param_groups[0]["params"]) + if n_params_saved == 2 * n_params: + # Truncate to include only online network params + params["policy.optimizer"]["param_groups"][0]["params"] = saved_optim_params[:n_params] # type: ignore[index] + + model.set_parameters(params, exact_match=True, device=device) + warnings.warn( + "You are probably loading a DQN model saved with SB3 < 2.4.0, " + "we truncated the optimizer state so you can save the model " + "again to avoid issues in the future " + "(see https://github.com/DLR-RM/stable-baselines3/pull/1963 for more info). " + f"Original error: {e} \n" + "Note: the model should still work fine, this only a warning." + ) + else: + raise e + # put other pytorch variables back in place if pytorch_variables is not None: for name in pytorch_variables: diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 464a5c4dc..f5230e413 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.4.0a6 +2.4.0a7 diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 5dc6ca7bf..962088246 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -340,7 +340,7 @@ def test_save_load_env_cnn(tmp_path, model_class): # clear file from os os.remove(tmp_path / "test_save.zip") - # Check we can load models saved with SB3 < 1.7.0 + # Check we can load A2C/PPO models saved with SB3 < 1.7.0 if model_class == A2C: del model.policy.pi_features_extractor model.save(tmp_path / "test_save") @@ -809,3 +809,15 @@ def test_save_load_net_arch_none(tmp_path): # None has been replaced by the default net arch assert model.policy.net_arch is not None os.remove(tmp_path / "ppo.zip") + + +def test_save_load_no_target_params(tmp_path): + # Check we can load DQN models saved with SB3 < 2.4.0 + model = DQN("MlpPolicy", "CartPole-v1", buffer_size=10000, learning_starts=4) + env = model.get_env() + # Include target net params + model.policy.optimizer = th.optim.Adam(model.policy.parameters(), lr=0.001) + model.save(tmp_path / "test_save") + with pytest.warns(UserWarning): + DQN.load(str(tmp_path / "test_save.zip"), env=env).learn(20) + os.remove(tmp_path / "test_save.zip") From 6ad6fa55b6e38c8456dd333f71fe45373f66fe90 Mon Sep 17 00:00:00 2001 From: Chris Schindlbeck Date: Mon, 29 Jul 2024 10:44:23 +0200 Subject: [PATCH 03/11] Fix various typos (#1981) --- CODE_OF_CONDUCT.md | 2 +- stable_baselines3/common/on_policy_algorithm.py | 2 +- stable_baselines3/her/her_replay_buffer.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index 137c95744..0ca033815 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -5,7 +5,7 @@ We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender -identity and expression, level of experience, education, socio-economic status, +identity and expression, level of experience, education, socioeconomic status, nationality, 
personal appearance, race, religion, or sexual identity and orientation. diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index 1ba36d5f0..262453721 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -208,7 +208,7 @@ def collect_rollouts( # Reshape in case of discrete action actions = actions.reshape(-1, 1) - # Handle timeout by bootstraping with value function + # Handle timeout by bootstrapping with value function # see GitHub issue #633 for idx, done in enumerate(dones): if ( diff --git a/stable_baselines3/her/her_replay_buffer.py b/stable_baselines3/her/her_replay_buffer.py index 579c6ebf1..20214e72c 100644 --- a/stable_baselines3/her/her_replay_buffer.py +++ b/stable_baselines3/her/her_replay_buffer.py @@ -396,7 +396,7 @@ def truncate_last_trajectory(self) -> None: "If you are in the same episode as when the replay buffer was saved,\n" "you should use `truncate_last_trajectory=False` to avoid that issue." ) - # only consider epsiodes that are not finished + # only consider episodes that are not finished for env_idx in np.where(self._current_ep_start != self.pos)[0]: # set done = True for last episodes self.dones[self.pos - 1, env_idx] = True From 4a1137ba3ac0ff0ae095aca564edc82ec37b7f1c Mon Sep 17 00:00:00 2001 From: Jan-Hendrik Ewers Date: Fri, 2 Aug 2024 10:55:27 +0100 Subject: [PATCH 04/11] Add np.ndarray as a recognized type for TB histograms. (#1635) * Add np.ndarray as a recognized type for TB histograms. Torch histograms allow th.Tensor, np.ndarray, and caffe2 formatted strings. This commits expands the TensorBoardOutputFormat's capabilities to log the two former types. * Update changelog to reflect bug fix * fix: try/catch for if either np or torch aren't at the required versions. See https://github.com/DLR-RM/stable-baselines3/pull/1635 for more details * fix: Add comment describing the test for when add_histogram should not have been called * Cleanup --------- Co-authored-by: Antonin RAFFIN --- docs/misc/changelog.rst | 3 +- stable_baselines3/common/logger.py | 5 ++- stable_baselines3/version.txt | 2 +- tests/test_logger.py | 70 ++++++++++++++++++++++++++---- 4 files changed, 67 insertions(+), 13 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 37a035478..9c461f6ae 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.4.0a7 (WIP) +Release 2.4.0a8 (WIP) -------------------------- .. note:: @@ -19,6 +19,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ - Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ) +- Enabled np.ndarray logging for TensorBoardOutputFormat as histogram (see GH#1634) (@iwishwasaneagle) Bug Fixes: ^^^^^^^^^^ diff --git a/stable_baselines3/common/logger.py b/stable_baselines3/common/logger.py index 363a9d2e8..8ceda71ed 100644 --- a/stable_baselines3/common/logger.py +++ b/stable_baselines3/common/logger.py @@ -412,8 +412,9 @@ def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, . 
else: self.writer.add_scalar(key, value, step) - if isinstance(value, th.Tensor): - self.writer.add_histogram(key, value, step) + if isinstance(value, (th.Tensor, np.ndarray)): + # Convert to Torch so it works with numpy<1.24 and torch<2.0 + self.writer.add_histogram(key, th.as_tensor(value), step) if isinstance(value, Video): self.writer.add_video(key, value.frames, step, value.fps) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index f5230e413..ee717ba15 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.4.0a7 +2.4.0a8 diff --git a/tests/test_logger.py b/tests/test_logger.py index dfa3691ed..bc18bf2ce 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -44,6 +44,7 @@ "f": np.array(1), "g": np.array([[[1]]]), "h": 'this ", ;is a \n tes:,t', + "i": th.ones(3), } KEY_EXCLUDED = {} @@ -176,6 +177,9 @@ def test_main(tmp_path): logger.record_mean("b", -22.5) logger.record_mean("b", -44.4) logger.record("a", 5.5) + # Converted to string: + logger.record("hist1", th.ones(2)) + logger.record("hist2", np.ones(2)) logger.dump() logger.record("a", "longasslongasslongasslongasslongasslongassvalue") @@ -241,7 +245,7 @@ def is_moviepy_installed(): @pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"]) -def test_report_video_to_unsupported_format_raises_error(tmp_path, unsupported_format): +def test_unsupported_video_format(tmp_path, unsupported_format): writer = make_output_format(unsupported_format, tmp_path) with pytest.raises(FormatUnsupportedError) as exec_info: @@ -251,6 +255,54 @@ def test_report_video_to_unsupported_format_raises_error(tmp_path, unsupported_f writer.close() +@pytest.mark.parametrize( + "histogram", + [ + th.rand(100), + np.random.rand(100), + np.ones(1), + np.ones(1, dtype="int"), + ], +) +def test_log_histogram(tmp_path, read_log, histogram): + pytest.importorskip("tensorboard") + + writer = make_output_format("tensorboard", tmp_path) + writer.write({"data": histogram}, key_excluded={"data": ()}) + + log = read_log("tensorboard") + + assert not log.empty + assert any("data" in line for line in log.lines) + assert any("Histogram" in line for line in log.lines) + + writer.close() + + +@pytest.mark.parametrize( + "histogram", + [ + list(np.random.rand(100)), + tuple(np.random.rand(100)), + "1 2 3 4", + np.ones(1).item(), + th.ones(1).item(), + ], +) +def test_unsupported_type_histogram(tmp_path, read_log, histogram): + """ + Check that other types aren't accidentally logged as a Histogram + """ + pytest.importorskip("tensorboard") + + writer = make_output_format("tensorboard", tmp_path) + writer.write({"data": histogram}, key_excluded={"data": ()}) + + assert all("Histogram" not in line for line in read_log("tensorboard").lines) + + writer.close() + + def test_report_image_to_tensorboard(tmp_path, read_log): pytest.importorskip("tensorboard") @@ -263,7 +315,7 @@ def test_report_image_to_tensorboard(tmp_path, read_log): @pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"]) -def test_report_image_to_unsupported_format_raises_error(tmp_path, unsupported_format): +def test_unsupported_image_format(tmp_path, unsupported_format): writer = make_output_format(unsupported_format, tmp_path) with pytest.raises(FormatUnsupportedError) as exec_info: @@ -287,7 +339,7 @@ def test_report_figure_to_tensorboard(tmp_path, read_log): @pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"]) -def 
test_report_figure_to_unsupported_format_raises_error(tmp_path, unsupported_format): +def test_unsupported_figure_format(tmp_path, unsupported_format): writer = make_output_format(unsupported_format, tmp_path) with pytest.raises(FormatUnsupportedError) as exec_info: @@ -300,7 +352,7 @@ def test_report_figure_to_unsupported_format_raises_error(tmp_path, unsupported_ @pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"]) -def test_report_hparam_to_unsupported_format_raises_error(tmp_path, unsupported_format): +def test_unsupported_hparam(tmp_path, unsupported_format): writer = make_output_format(unsupported_format, tmp_path) with pytest.raises(FormatUnsupportedError) as exec_info: @@ -419,9 +471,9 @@ def test_fps_no_div_zero(algo): model.learn(total_timesteps=100) -def test_human_output_format_no_crash_on_same_keys_different_tags(): - o = HumanOutputFormat(sys.stdout, max_length=60) - o.write( +def test_human_output_same_keys_different_tags(): + human_out = HumanOutputFormat(sys.stdout, max_length=60) + human_out.write( {"key1/foo": "value1", "key1/bar": "value2", "key2/bizz": "value3", "key2/foo": "value4"}, {"key1/foo": None, "key2/bizz": None, "key1/bar": None, "key2/foo": None}, ) @@ -439,7 +491,7 @@ def test_ep_buffers_stats_window_size(algo, stats_window_size): @pytest.mark.parametrize("base_class", [object, TextIOBase]) -def test_human_output_format_custom_test_io(base_class): +def test_human_out_custom_text_io(base_class): class DummyTextIO(base_class): def __init__(self) -> None: super().__init__() @@ -531,7 +583,7 @@ def step(self, action): return self.observation_space.sample(), 0.0, False, truncated, info -def test_rollout_success_rate_on_policy_algorithm(tmp_path): +def test_rollout_success_rate_onpolicy_algo(tmp_path): """ Test if the rollout/success_rate information is correctly logged with on policy algorithms From 4a7631b71d0d7836df977c89303f3428964fdb59 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sun, 18 Aug 2024 12:33:22 +0200 Subject: [PATCH 05/11] Fix test device for buffers (#1993) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Prevent test_device from being a noop * Update changelog --------- Co-authored-by: Adrià Garriga-Alonso --- docs/misc/changelog.rst | 2 ++ tests/test_buffers.py | 21 ++++++++++++++------- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 9c461f6ae..cc417a9d3 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -30,6 +30,8 @@ Bug Fixes: - Fixed error when loading a model that has ``net_arch`` manually set to ``None`` (@jak3122) - Set requirement numpy<2.0 until PyTorch is compatible (https://github.com/pytorch/pytorch/issues/107302) - Updated DQN optimizer input to only include q_network parameters, removing the target_q_network ones (@corentinlger) +- Fixed ``test_buffers.py::test_device`` which was not actually checking the device of tensors (@rhaps0dy) + `SB3-Contrib`_ ^^^^^^^^^^^^^^ diff --git a/tests/test_buffers.py b/tests/test_buffers.py index da6b44a34..18171dd21 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -139,18 +139,25 @@ def test_device_buffer(replay_buffer_cls, device): # Get data from the buffer if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]: + # get returns an iterator over minibatches data = buffer.get(50) elif replay_buffer_cls in [ReplayBuffer, DictReplayBuffer]: - data = buffer.sample(50) + data = [buffer.sample(50)] # Check that all 
data are on the desired device desired_device = get_device(device).type - for value in list(data): - if isinstance(value, dict): - for key in value.keys(): - assert value[key].device.type == desired_device - elif isinstance(value, th.Tensor): - assert value.device.type == desired_device + for minibatch in list(data): + for value in minibatch: + if isinstance(value, dict): + for key in value.keys(): + assert value[key].device.type == desired_device + elif isinstance(value, th.Tensor): + assert value.device.type == desired_device + elif isinstance(value, np.ndarray): + # For prioritized replay weights/indices + pass + else: + raise TypeError(f"Unknown value type: {type(value)}") def test_custom_rollout_buffer(): From 9a3b28bb9f24a1646479500fb23be55ba652a30d Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 23 Aug 2024 08:58:43 +0200 Subject: [PATCH 06/11] [ci skip] Update README.md, fix image display --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 78592bae8..52634e486 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,3 @@ - - ![CI](https://github.com/DLR-RM/stable-baselines3/workflows/CI/badge.svg) [![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/master/coverage.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master) @@ -8,6 +6,8 @@ # Stable Baselines3 + + Stable Baselines3 (SB3) is a set of reliable implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines). You can read a detailed presentation of Stable Baselines3 in the [v1.0 blog post](https://araffin.github.io/post/sb3/) or our [JMLR paper](https://jmlr.org/papers/volume22/20-1364/20-1364.pdf). From 512eea923afad6f6da4bb53d72b6ea4c6d856e59 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 13 Sep 2024 13:15:23 +0200 Subject: [PATCH 07/11] Warn users when using multi-dim MultiDiscrete obs space (#2003) * Update env checker to warn users when using multi-dim MultiDiscrete obs space * Update changelog --- docs/misc/changelog.rst | 10 +++++++++- stable_baselines3/common/env_checker.py | 8 ++++++++ stable_baselines3/version.txt | 2 +- tests/test_envs.py | 2 ++ 4 files changed, 20 insertions(+), 2 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index cc417a9d3..e8a2984d2 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.4.0a8 (WIP) +Release 2.4.0a9 (WIP) -------------------------- .. note:: @@ -13,6 +13,13 @@ Release 2.4.0a8 (WIP) To suppress the warning, simply save the model again. You can find more info in `PR #1963 `_ +.. warning:: + + Stable-Baselines3 (SB3) v2.4.0 will be the last one supporting Python 3.8 (end of life in October 2024) + and PyTorch < 2.0. + We highly recommended you to upgrade to Python >= 3.9 and PyTorch >= 2.0. 
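The headline change of this patch is the env checker update below: SB3 only supports 1D ``MultiDiscrete`` observation spaces, so a multidimensional ``nvec`` now triggers a warning. A minimal space that trips it, taken from the new test case:

```python
import numpy as np
from gymnasium import spaces

# 2D MultiDiscrete: nvec has shape (2, 2), which SB3 does not support.
# check_env now warns and points to a flattening wrapper (see GH#1836).
obs_space = spaces.MultiDiscrete(np.array([[4, 4], [2, 3]]))
print(obs_space.nvec.shape)  # (2, 2)
```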
+ + Breaking Changes: ^^^^^^^^^^^^^^^^^ @@ -20,6 +27,7 @@ New Features: ^^^^^^^^^^^^^ - Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ) - Enabled np.ndarray logging for TensorBoardOutputFormat as histogram (see GH#1634) (@iwishwasaneagle) +- Updated env checker to warn users when using multi-dim array to define `MultiDiscrete` spaces Bug Fixes: ^^^^^^^^^^ diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index 090d609ba..e47dd123a 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -98,6 +98,14 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act "is not supported but `dict(space2=Box(), spaces3=Box(), spaces4=Discrete())` is." ) + if isinstance(observation_space, spaces.MultiDiscrete) and len(observation_space.nvec.shape) > 1: + warnings.warn( + f"The MultiDiscrete observation space uses a multidimensional array {observation_space.nvec} " + "which is currently not supported by Stable-Baselines3. " + "Please convert it to a 1D array using a wrapper: " + "https://github.com/DLR-RM/stable-baselines3/issues/1836." + ) + if isinstance(observation_space, spaces.Tuple): warnings.warn( "The observation space is a Tuple, " diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index ee717ba15..636c433a1 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.4.0a8 +2.4.0a9 diff --git a/tests/test_envs.py b/tests/test_envs.py index 9a61eeef0..2fbce120c 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -123,6 +123,8 @@ def patched_step(_action): spaces.Dict({"img": spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)}), # Non zero start index spaces.Discrete(3, start=-1), + # 2D MultiDiscrete + spaces.MultiDiscrete(np.array([[4, 4], [2, 3]])), # Non zero start index (MultiDiscrete) spaces.MultiDiscrete([4, 4], start=[1, 0]), # Non zero start index inside a Dict From 56c153f048f1035f239b77d1569b240ace83c130 Mon Sep 17 00:00:00 2001 From: Devin White <87712306+Dev1nW@users.noreply.github.com> Date: Mon, 7 Oct 2024 04:24:47 -0500 Subject: [PATCH 08/11] Add warning when using PPO on GPU and update doc (#2017) * Update documentation Added comment to PPO documentation that CPU should primarily be used unless using CNN as well as sample code. Added warning to user for both PPO and A2C that CPU should be used if the user is running GPU without using a CNN, reference Issue #1245. * Add warning to base class and add test --------- Co-authored-by: Antonin RAFFIN --- docs/misc/changelog.rst | 4 +++- docs/modules/ppo.rst | 17 ++++++++++++++ .../common/on_policy_algorithm.py | 23 +++++++++++++++++++ stable_baselines3/version.txt | 2 +- tests/test_run.py | 14 +++++++++-- 5 files changed, 56 insertions(+), 4 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index e8a2984d2..af83d2302 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.4.0a9 (WIP) +Release 2.4.0a10 (WIP) -------------------------- .. 
note:: @@ -60,12 +60,14 @@ Others: - Fixed various typos (@cschindlbeck) - Remove unnecessary SDE noise resampling in PPO update (@brn-dev) - Updated PyTorch version on CI to 2.3.1 +- Added a warning to recommend using CPU with on policy algorithms (A2C/PPO) and ``MlpPolicy`` Bug Fixes: ^^^^^^^^^^ Documentation: ^^^^^^^^^^^^^^ +- Updated PPO doc to recommend using CPU with ``MlpPolicy`` Release 2.3.2 (2024-04-27) -------------------------- diff --git a/docs/modules/ppo.rst b/docs/modules/ppo.rst index b5e667241..4285cfb50 100644 --- a/docs/modules/ppo.rst +++ b/docs/modules/ppo.rst @@ -88,6 +88,23 @@ Train a PPO agent on ``CartPole-v1`` using 4 environments. vec_env.render("human") +.. note:: + + PPO is meant to be run primarily on the CPU, especially when you are not using a CNN. To improve CPU utilization, try turning off the GPU and using ``SubprocVecEnv`` instead of the default ``DummyVecEnv``: + + .. code-block:: + + from stable_baselines3 import PPO + from stable_baselines3.common.env_util import make_vec_env + from stable_baselines3.common.vec_env import SubprocVecEnv + + if __name__=="__main__": + env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv) + model = PPO("MlpPolicy", env, device="cpu") + model.learn(total_timesteps=25_000) + + For more information, see :ref:`Vectorized Environments `, `Issue #1245 `_ or the `Multiprocessing notebook `_. + Results ------- diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index 262453721..dc885242e 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -1,5 +1,6 @@ import sys import time +import warnings from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union import numpy as np @@ -135,6 +136,28 @@ def _setup_model(self) -> None: self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, **self.policy_kwargs ) self.policy = self.policy.to(self.device) + # Warn when not using CPU with MlpPolicy + self._maybe_recommend_cpu() + + def _maybe_recommend_cpu(self, mlp_class_name: str = "ActorCriticPolicy") -> None: + """ + Recommend to use CPU only when using A2C/PPO with MlpPolicy. + + :param: The name of the class for the default MlpPolicy. + """ + policy_class_name = self.policy_class.__name__ + if self.device != th.device("cpu") and policy_class_name == mlp_class_name: + warnings.warn( + f"You are trying to run {self.__class__.__name__} on the GPU, " + "but it is primarily intended to run on the CPU when not using a CNN policy " + f"(you are using {policy_class_name} which should be a MlpPolicy). " + "See https://github.com/DLR-RM/stable-baselines3/issues/1245 " + "for more info. " + "You can pass `device='cpu'` or `export CUDA_VISIBLE_DEVICES=` to force using the CPU." 
+ "Note: The model will train, but the GPU utilization will be poor and " + "the training might take longer than on CPU.", + UserWarning, + ) def collect_rollouts( self, diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 636c433a1..852a32b3f 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.4.0a9 +2.4.0a10 diff --git a/tests/test_run.py b/tests/test_run.py index 31c7b956e..4acabb692 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,6 +1,7 @@ import gymnasium as gym import numpy as np import pytest +import torch as th from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 from stable_baselines3.common.env_util import make_vec_env @@ -211,8 +212,11 @@ def test_warn_dqn_multi_env(): def test_ppo_warnings(): - """Test that PPO warns and errors correctly on - problematic rollout buffer sizes""" + """ + Test that PPO warns and errors correctly on + problematic rollout buffer sizes, + and recommend using CPU. + """ # Only 1 step: advantage normalization will return NaN with pytest.raises(AssertionError): @@ -234,3 +238,9 @@ def test_ppo_warnings(): loss = model.logger.name_to_value["train/loss"] assert loss > 0 assert not np.isnan(loss) # check not nan (since nan does not equal nan) + + with pytest.warns(UserWarning, match="You are trying to run PPO on the GPU"): + model = PPO("MlpPolicy", "Pendulum-v1") + # Pretend to be on the GPU + model.device = th.device("cuda") + model._maybe_recommend_cpu() From 3d59b5c86b0d8d61ee4a68cb2ae8743fd178670b Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Thu, 24 Oct 2024 22:20:05 +0900 Subject: [PATCH 09/11] Use uv on GitHub CI for faster download and update changelog (#2026) * Use uv on GitHub CI for faster download and update changelog * Fix new mypy issues --- .github/workflows/ci.yml | 11 +++++++---- docs/guide/sb3_contrib.rst | 1 + docs/misc/changelog.rst | 7 +++++++ stable_baselines3/common/utils.py | 4 ++-- tests/test_utils.py | 2 +- 5 files changed, 18 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 822e0cb3f..cb9055266 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -31,18 +31,21 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip + # Use uv for faster downloads + pip install uv # cpu version of pytorch - pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu + # See https://github.com/astral-sh/uv/issues/1497 + uv pip install --system torch==2.3.1+cpu --index https://download.pytorch.org/whl/cpu # Install Atari Roms - pip install autorom + uv pip install --system autorom wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64 base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz AutoROM --accept-license --source-file Roms.tar.gz - pip install .[extra_no_roms,tests,docs] + uv pip install --system .[extra_no_roms,tests,docs] # Use headless version - pip install opencv-python-headless + uv pip install --system opencv-python-headless - name: Lint with ruff run: | make lint diff --git a/docs/guide/sb3_contrib.rst b/docs/guide/sb3_contrib.rst index 445832c59..8ec912e15 100644 --- a/docs/guide/sb3_contrib.rst +++ b/docs/guide/sb3_contrib.rst @@ -42,6 +42,7 @@ See documentation for the full list of included features. 
- `PPO with recurrent policy (RecurrentPPO aka PPO LSTM) `_ - `Truncated Quantile Critics (TQC)`_ - `Trust Region Policy Optimization (TRPO) `_ +- `Batch Normalization in Deep Reinforcement Learning (CrossQ) `_ **Gym Wrappers**: diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index af83d2302..2c0974ac2 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -6,6 +6,8 @@ Changelog Release 2.4.0a10 (WIP) -------------------------- +**New algorithm: CrossQ in SB3 Contrib** + .. note:: DQN (and QR-DQN) models saved with SB3 < 2.4.0 will show a warning about @@ -43,6 +45,10 @@ Bug Fixes: `SB3-Contrib`_ ^^^^^^^^^^^^^^ +- Added ``CrossQ`` algorithm, from "Batch Normalization in Deep Reinforcement Learning" paper (@danielpalen) +- Added ``BatchRenorm`` PyTorch layer used in ``CrossQ`` (@danielpalen) +- Updated QR-DQN optimizer input to only include quantile_net parameters (@corentinlger) +- Fixed loading QRDQN changes `target_update_interval` (@jak3122) `RL Zoo`_ ^^^^^^^^^ @@ -61,6 +67,7 @@ Others: - Remove unnecessary SDE noise resampling in PPO update (@brn-dev) - Updated PyTorch version on CI to 2.3.1 - Added a warning to recommend using CPU with on policy algorithms (A2C/PPO) and ``MlpPolicy`` +- Switched to uv to download packages faster on GitHub CI Bug Fixes: ^^^^^^^^^^ diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index bcde1cfa0..4e9fbc2db 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -46,7 +46,7 @@ def set_random_seed(seed: int, using_cuda: bool = False) -> None: # From stable baselines -def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray: +def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> float: """ Computes fraction of variance that ypred explains about y. 
Returns 1 - Var[y-ypred] / Var[y] @@ -62,7 +62,7 @@ def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray: """ assert y_true.ndim == 1 and y_pred.ndim == 1 var_y = np.var(y_true) - return np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y + return np.nan if var_y == 0 else float(1 - np.var(y_true - y_pred) / var_y) def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) -> None: diff --git a/tests/test_utils.py b/tests/test_utils.py index 4cc8b7e9f..81f134168 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -177,7 +177,7 @@ def test_custom_vec_env(tmp_path): @pytest.mark.parametrize("direct_policy", [False, True]) -def test_evaluate_policy(direct_policy: bool): +def test_evaluate_policy(direct_policy): model = A2C("MlpPolicy", "Pendulum-v1", seed=0) n_steps_per_episode, n_eval_episodes = 200, 2 From dd3d0acf154dec2b8a9a92fcc5fb83e4a05eaf72 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 29 Oct 2024 12:23:13 +0100 Subject: [PATCH 10/11] Update readme and clarify planned features (#2030) * Update readme and clarify planned features * Fix rtd python version * Fix pip version for rtd * Update rtd ubuntu and mambaforge * Add upper bound for gymnasium * [ci skip] Update readme --- .readthedocs.yml | 4 ++-- CONTRIBUTING.md | 2 +- README.md | 32 +++++++++++++++++++++----------- docs/conda_env.yml | 12 ++++++------ docs/guide/algos.rst | 1 + docs/index.rst | 4 +++- docs/misc/changelog.rst | 2 ++ 7 files changed, 36 insertions(+), 21 deletions(-) diff --git a/.readthedocs.yml b/.readthedocs.yml index dbb2fad03..26f0c883b 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -16,6 +16,6 @@ conda: environment: docs/conda_env.yml build: - os: ubuntu-22.04 + os: ubuntu-24.04 tools: - python: "mambaforge-22.9" + python: "mambaforge-23.11" diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d295269a9..cc5d1075b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -6,7 +6,7 @@ into two categories: - Create an issue about your intended feature, and we shall discuss the design and implementation. Once we agree that the plan looks good, go ahead and implement it. 2. You want to implement a feature or bug-fix for an outstanding issue - - Look at the outstanding issues here: https://github.com/DLR-RM/stable-baselines3/issues + - Look at the outstanding issues here: https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted - Pick an issue or feature and comment on the task that you want to work on this feature. - If you need more context on a particular issue, please ask, and we shall provide. 
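Returning to the `explained_variance` change above (it now returns a plain Python `float` instead of a 0-d array), a quick numeric check of the formula `1 - Var[y - y_pred] / Var[y]`:

```python
import numpy as np

from stable_baselines3.common.utils import explained_variance

y_true = np.array([1.0, 2.0, 3.0, 4.0])
y_pred = np.array([1.0, 2.0, 3.0, 5.0])

ev = explained_variance(y_pred, y_true)
# Var[y_true - y_pred] = Var([0, 0, 0, -1]) = 0.1875
# Var[y_true] = 1.25, so ev = 1 - 0.1875 / 1.25 = 0.85
assert isinstance(ev, float)
print(ev)  # 0.85
```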
diff --git a/README.md b/README.md index 52634e486..5d25781d9 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ -![CI](https://github.com/DLR-RM/stable-baselines3/workflows/CI/badge.svg) -[![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/master/coverage.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master) +[![CI](https://github.com/DLR-RM/stable-baselines3/workflows/CI/badge.svg)](https://github.com/DLR-RM/stable-baselines3/actions/workflows/ci.yml) +[![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/master/coverage.svg)](https://github.com/DLR-RM/stable-baselines3/actions/workflows/ci.yml) [![codestyle](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) @@ -22,6 +22,8 @@ These algorithms will make it easier for the research community and industry to **The performance of each algorithm was tested** (see *Results* section in their respective page), you can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselines3/issues/48) and [#49](https://github.com/DLR-RM/stable-baselines3/issues/49) for more details. +We also provide detailed logs and reports on the [OpenRL Benchmark](https://wandb.ai/openrlbenchmark/sb3) platform. + | **Features** | **Stable-Baselines3** | | --------------------------- | ----------------------| @@ -41,7 +43,13 @@ you can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselin ### Planned features -Please take a look at the [Roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) and [Milestones](https://github.com/DLR-RM/stable-baselines3/milestones). +Since most of the features from the [original roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) have been implemented, there are no major changes planned for SB3, it is now *stable*. +If you want to contribute, you can search in the issues for the ones where [help is welcomed](https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted) and the other [proposed enhancements](https://github.com/DLR-RM/stable-baselines3/labels/enhancement). + +While SB3 development is now focused on bug fixes and maintenance (doc update, user experience, ...), there is more active development going on in the associated repositories: +- newer algorithms are regularly added to the [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) repository +- faster variants are developed in the [SBX (SB3 + Jax)](https://github.com/araffin/sbx) repository +- the training framework for SB3, the RL Zoo, has an active [roadmap](https://github.com/DLR-RM/rl-baselines3-zoo/issues/299) ## Migration guide: from Stable-Baselines (SB2) to Stable-Baselines3 (SB3) @@ -79,7 +87,7 @@ Documentation: https://rl-baselines3-zoo.readthedocs.io/en/master/ We implement experimental features in a separate contrib repository: [SB3-Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) -This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Recurrent PPO (PPO LSTM), Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO). 
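Contrib algorithms follow the same API as the core ones, so they are drop-in replacements. A minimal sketch, assuming `sb3-contrib` is installed:

```python
from sb3_contrib import TQC  # CrossQ, QRDQN, MaskablePPO, ... live here too

model = TQC("MlpPolicy", "Pendulum-v1", verbose=1)
model.learn(total_timesteps=1_000)
```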
+This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Recurrent PPO (PPO LSTM), CrossQ, Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO). Documentation is available online: [https://sb3-contrib.readthedocs.io/](https://sb3-contrib.readthedocs.io/) @@ -97,17 +105,16 @@ It provides a minimal number of features compared to SB3 but can be much faster ### Prerequisites Stable Baselines3 requires Python 3.8+. -#### Windows 10 +#### Windows To install stable-baselines on Windows, please look at the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/install.html#prerequisites). ### Install using pip Install the Stable Baselines3 package: +```sh +pip install 'stable-baselines3[extra]' ``` -pip install stable-baselines3[extra] -``` -**Note:** Some shells such as Zsh require quotation marks around brackets, i.e. `pip install 'stable-baselines3[extra]'` ([More Info](https://stackoverflow.com/a/30539963)). This includes an optional dependencies like Tensorboard, OpenCV or `ale-py` to train on atari games. If you do not need those, you can use: ```sh @@ -177,6 +184,7 @@ All the following examples can be executed online using Google Colab notebooks: | ------------------- | ------------------ | ------------------ | ------------------ | ------------------- | ------------------ | --------------------------------- | | ARS[1](#f1) | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | | A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | +| CrossQ[1](#f1) | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: | | DDPG | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: | | DQN | :x: | :x: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | | HER | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | @@ -191,7 +199,7 @@ All the following examples can be executed online using Google Colab notebooks: 1: Implemented in [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) GitHub repository. -Actions `gym.spaces`: +Actions `gymnasium.spaces`: * `Box`: A N-dimensional box that contains every point in the action space. * `Discrete`: A list of possible actions, where each timestep only one of the actions can be used. * `MultiDiscrete`: A list of possible actions, where each timestep only one action of each discrete set can be used. @@ -218,9 +226,9 @@ To run a single test: python3 -m pytest -v -k 'test_check_env_dict_action' ``` -You can also do a static type check using `pytype` and `mypy`: +You can also do a static type check using `mypy`: ```sh -pip install pytype mypy +pip install mypy make type ``` @@ -252,6 +260,8 @@ To cite this repository in publications: } ``` +Note: If you need to refer to a specific version of SB3, you can also use the [Zenodo DOI](https://doi.org/10.5281/zenodo.8123988). + ## Maintainers Stable-Baselines3 is currently maintained by [Ashley Hill](https://github.com/hill-a) (aka @hill-a), [Antonin Raffin](https://araffin.github.io/) (aka [@araffin](https://github.com/araffin)), [Maximilian Ernestus](https://github.com/ernestum) (aka @ernestum), [Adam Gleave](https://github.com/adamgleave) (@AdamGleave), [Anssi Kanervisto](https://github.com/Miffyli) (@Miffyli) and [Quentin Gallouédec](https://gallouedec.com/) (@qgallouedec). 
diff --git a/docs/conda_env.yml b/docs/conda_env.yml index 53fecf278..e025a57e1 100644 --- a/docs/conda_env.yml +++ b/docs/conda_env.yml @@ -1,18 +1,18 @@ name: root channels: - pytorch - - defaults + - conda-forge dependencies: - cpuonly=1.0=0 - - pip=22.3.1 - - python=3.8 - - pytorch=1.13.0=py3.8_cpu_0 + - pip=24.2 + - python=3.11 + - pytorch=2.5.0=py3.11_cpu_0 - pip: - - gymnasium + - gymnasium>=0.28.1,<0.30 - cloudpickle - opencv-python-headless - pandas - - numpy + - numpy>=1.20,<2.0 - matplotlib - sphinx>=5,<8 - sphinx_rtd_theme>=1.3.0 diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst index d5e7ae1d2..db03ba292 100644 --- a/docs/guide/algos.rst +++ b/docs/guide/algos.rst @@ -10,6 +10,7 @@ Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` =================== =========== ============ ================= =============== ================ ARS [#f1]_ ✔️ ✔️ ❌ ❌ ✔️ A2C ✔️ ✔️ ✔️ ✔️ ✔️ +CrossQ [#f1]_ ✔️ ❌ ❌ ❌ ✔️ DDPG ✔️ ❌ ❌ ❌ ✔️ DQN ❌ ✔️ ❌ ❌ ✔️ HER ✔️ ✔️ ❌ ❌ ✔️ diff --git a/docs/index.rst b/docs/index.rst index c8a70a94b..d74120c41 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -113,12 +113,14 @@ To cite this project in publications: url = {http://jmlr.org/papers/v22/20-1364.html} } +Note: If you need to refer to a specific version of SB3, you can also use the `Zenodo DOI `_. + Contributing ------------ To any interested in making the rl baselines better, there are still some improvements that need to be done. -You can check issues in the `repo `_. +You can check issues in the `repository `_. If you want to contribute, please read `CONTRIBUTING.md `_ first. diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 2c0974ac2..b32cd7ce1 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -68,6 +68,7 @@ Others: - Updated PyTorch version on CI to 2.3.1 - Added a warning to recommend using CPU with on policy algorithms (A2C/PPO) and ``MlpPolicy`` - Switched to uv to download packages faster on GitHub CI +- Updated dependencies for read the doc Bug Fixes: ^^^^^^^^^^ @@ -75,6 +76,7 @@ Bug Fixes: Documentation: ^^^^^^^^^^^^^^ - Updated PPO doc to recommend using CPU with ``MlpPolicy`` +- Clarified documentation about planned features and citing software Release 2.3.2 (2024-04-27) -------------------------- From 8f0b488bc5a897f1ac2b95f493bcb6b7e92d311c Mon Sep 17 00:00:00 2001 From: Mark Towers Date: Mon, 4 Nov 2024 11:03:12 +0000 Subject: [PATCH 11/11] Update Gymnasium to v1.0.0 (#1837) * Update Gymnasium to v1.0.0a1 * Comment out `gymnasium.wrappers.monitor` (todo update to VideoRecord) * Fix ruff warnings * Register Atari envs * Update `getattr` to `Env.get_wrapper_attr` * Reorder imports * Fix `seed` order * Fix collecting `max_steps` * Copy and paste video recorder to prevent the need to rewrite the vec vide recorder wrapper * Use `typing.List` rather than list * Fix env attribute forwarding * Separate out env attribute collection from its utilisation * Update for Gymnasium alpha 2 * Remove assert for OrderedDict * Update setup.py * Add type: ignore * Test with Gymnasium main * Remove `gymnasium.logger.debug/info` * Fix github CI yaml * Run gym 0.29.1 on python 3.10 * Update lower bounds * Integrate video recorder * Remove ordered dict * Update changelog --------- Co-authored-by: Antonin RAFFIN --- .github/workflows/ci.yml | 20 ++-- docs/conda_env.yml | 2 +- docs/misc/changelog.rst | 7 +- pyproject.toml | 1 - setup.py | 43 +++---- .../common/vec_env/dummy_vec_env.py | 8 +- stable_baselines3/common/vec_env/patch_gym.py | 2 +- 
.../common/vec_env/subproc_vec_env.py | 34 +++--- stable_baselines3/common/vec_env/util.py | 20 +--- .../common/vec_env/vec_video_recorder.py | 109 ++++++++++++------ stable_baselines3/version.txt | 2 +- tests/test_dict_env.py | 3 +- tests/test_gae.py | 2 +- tests/test_logger.py | 10 +- tests/test_utils.py | 3 + tests/test_vec_envs.py | 2 +- 16 files changed, 148 insertions(+), 120 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cb9055266..d34a93c9a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,7 +21,12 @@ jobs: strategy: matrix: python-version: ["3.8", "3.9", "3.10", "3.11"] - + include: + # Default version + - gymnasium-version: "1.0.0" + # Add a new config to test gym<1.0 + - python-version: "3.10" + gymnasium-version: "0.29.1" steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} @@ -37,15 +42,14 @@ jobs: # See https://github.com/astral-sh/uv/issues/1497 uv pip install --system torch==2.3.1+cpu --index https://download.pytorch.org/whl/cpu - # Install Atari Roms - uv pip install --system autorom - wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64 - base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz - AutoROM --accept-license --source-file Roms.tar.gz - - uv pip install --system .[extra_no_roms,tests,docs] + uv pip install --system .[extra,tests,docs] # Use headless version uv pip install --system opencv-python-headless + - name: Install specific version of gym + run: | + uv pip install --system gymnasium==${{ matrix.gymnasium-version }} + # Only run for python 3.10, downgrade gym to 0.29.1 + if: matrix.gymnasium-version != '1.0.0' - name: Lint with ruff run: | make lint diff --git a/docs/conda_env.yml b/docs/conda_env.yml index e025a57e1..ac065b3b9 100644 --- a/docs/conda_env.yml +++ b/docs/conda_env.yml @@ -8,7 +8,7 @@ dependencies: - python=3.11 - pytorch=2.5.0=py3.11_cpu_0 - pip: - - gymnasium>=0.28.1,<0.30 + - gymnasium>=0.29.1,<1.1.0 - cloudpickle - opencv-python-headless - pandas diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index b32cd7ce1..cf2a2a520 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,10 +3,10 @@ Changelog ========== -Release 2.4.0a10 (WIP) +Release 2.4.0a11 (WIP) -------------------------- -**New algorithm: CrossQ in SB3 Contrib** +**New algorithm: CrossQ in SB3 Contrib, Gymnasium v1.0 support** .. 
note:: @@ -24,12 +24,14 @@ Release 2.4.0a10 (WIP) Breaking Changes: ^^^^^^^^^^^^^^^^^ +- Increased minimum required version of Gymnasium to 0.29.1 New Features: ^^^^^^^^^^^^^ - Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ) - Enabled np.ndarray logging for TensorBoardOutputFormat as histogram (see GH#1634) (@iwishwasaneagle) - Updated env checker to warn users when using multi-dim array to define `MultiDiscrete` spaces +- Added support for Gymnasium v1.0 Bug Fixes: ^^^^^^^^^^ @@ -69,6 +71,7 @@ Others: - Added a warning to recommend using CPU with on policy algorithms (A2C/PPO) and ``MlpPolicy`` - Switched to uv to download packages faster on GitHub CI - Updated dependencies for Read the Docs +- Removed unnecessary ``copy_obs_dict`` method for ``DummyVecEnv``, removed the use of ordered dict and renamed ``_flatten_obs`` to ``_stack_obs`` in ``SubprocVecEnv`` Bug Fixes: ^^^^^^^^^^ diff --git a/pyproject.toml b/pyproject.toml index dd435a33e..1fd1a1890 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,6 @@ ignore = ["B028", "RUF013"] # ClassVar, implicit optional check not needed for tests "./tests/*.py" = ["RUF012", "RUF013"] - [tool.ruff.lint.mccabe] # Unlike Flake8, default to a complexity level of 10. max-complexity = 15 diff --git a/setup.py b/setup.py index 9d56dfd77..52f626462 100644 --- a/setup.py +++ b/setup.py @@ -70,37 +70,13 @@ """ # noqa:E501 -# Atari Games download is sometimes problematic: -# https://github.com/Farama-Foundation/AutoROM/issues/39 -# That's why we define extra packages without it. -extra_no_roms = [ - # For render - "opencv-python", - "pygame", - # Tensorboard support - "tensorboard>=2.9.1", - # Checking memory taken by replay buffer - "psutil", - # For progress bar callback - "tqdm", - "rich", - # For atari games, - "shimmy[atari]~=1.3.0", - "pillow", -] - -extra_packages = extra_no_roms + [ # noqa: RUF005 - # For atari roms, - "autorom[accept-rom-license]~=0.6.1", -] - setup( name="stable_baselines3", packages=[package for package in find_packages() if package.startswith("stable_baselines3")], package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ - "gymnasium>=0.28.1,<0.30", + "gymnasium>=0.29.1,<1.1.0", "numpy>=1.20,<2.0", # PyTorch not compatible https://github.com/pytorch/pytorch/issues/107302 "torch>=1.13", # For saving models @@ -133,8 +109,21 @@ # Copy button for code snippets "sphinx_copybutton", ], - "extra": extra_packages, - "extra_no_roms": extra_no_roms, + "extra": [ + # For render + "opencv-python", + "pygame", + # Tensorboard support + "tensorboard>=2.9.1", + # Checking memory taken by replay buffer + "psutil", + # For progress bar callback + "tqdm", + "rich", + # For atari games, + "ale-py>=0.9.0", + "pillow", + ], }, description="Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.", author="Antonin Raffin", diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index 15ecfb681..5625e2453 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -8,7 +8,7 @@ from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn from stable_baselines3.common.vec_env.patch_gym import _patch_env -from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info +from stable_baselines3.common.vec_env.util 
import dict_to_obs, obs_space_info class DummyVecEnv(VecEnv): @@ -110,12 +110,12 @@ def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None: self.buf_obs[key][env_idx] = obs[key] # type: ignore[call-overload] def _obs_from_buf(self) -> VecEnvObs: - return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs)) + return dict_to_obs(self.observation_space, deepcopy(self.buf_obs)) def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: """Return attribute from vectorized environment (see base class).""" target_envs = self._get_target_envs(indices) - return [getattr(env_i, attr_name) for env_i in target_envs] + return [env_i.get_wrapper_attr(attr_name) for env_i in target_envs] def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None: """Set attribute inside vectorized environments (see base class).""" @@ -126,7 +126,7 @@ def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]: """Call instance methods of vectorized environments.""" target_envs = self._get_target_envs(indices) - return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs] + return [env_i.get_wrapper_attr(method_name)(*method_args, **method_kwargs) for env_i in target_envs] def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]: """Check if worker environments are wrapped with a given wrapper""" diff --git a/stable_baselines3/common/vec_env/patch_gym.py b/stable_baselines3/common/vec_env/patch_gym.py index 6ba655ebf..874809a03 100644 --- a/stable_baselines3/common/vec_env/patch_gym.py +++ b/stable_baselines3/common/vec_env/patch_gym.py @@ -43,7 +43,7 @@ def _patch_env(env: Union["gym.Env", gymnasium.Env]) -> gymnasium.Env: # pragma "Missing shimmy installation. You provided an OpenAI Gym environment. " "Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. " "In order to use OpenAI Gym environments with SB3, you need to " - "install shimmy (`pip install 'shimmy>=0.2.1'`)." + "install shimmy (`pip install 'shimmy>=2.0'`)." 
) from e warnings.warn( diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index c598c735a..a606a7cb9 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -1,6 +1,5 @@ import multiprocessing as mp import warnings -from collections import OrderedDict from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union import gymnasium as gym @@ -54,10 +53,10 @@ def _worker( elif cmd == "get_spaces": remote.send((env.observation_space, env.action_space)) elif cmd == "env_method": - method = getattr(env, data[0]) + method = env.get_wrapper_attr(data[0]) remote.send(method(*data[1], **data[2])) elif cmd == "get_attr": - remote.send(getattr(env, data)) + remote.send(env.get_wrapper_attr(data)) elif cmd == "set_attr": remote.send(setattr(env, data[0], data[1])) # type: ignore[func-returns-value] elif cmd == "is_wrapped": @@ -129,7 +128,7 @@ def step_wait(self) -> VecEnvStepReturn: results = [remote.recv() for remote in self.remotes] self.waiting = False obs, rews, dones, infos, self.reset_infos = zip(*results) # type: ignore[assignment] - return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos # type: ignore[return-value] + return _stack_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos # type: ignore[return-value] def reset(self) -> VecEnvObs: for env_idx, remote in enumerate(self.remotes): @@ -139,7 +138,7 @@ def reset(self) -> VecEnvObs: # Seeds and options are only used once self._reset_seeds() self._reset_options() - return _flatten_obs(obs, self.observation_space) + return _stack_obs(obs, self.observation_space) def close(self) -> None: if self.closed: @@ -206,27 +205,28 @@ def _get_target_remotes(self, indices: VecEnvIndices) -> List[Any]: return [self.remotes[i] for i in indices] -def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: spaces.Space) -> VecEnvObs: +def _stack_obs(obs_list: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: spaces.Space) -> VecEnvObs: """ - Flatten observations, depending on the observation space. + Stack observations (convert from a list of single env obs to a stack of obs), + depending on the observation space. :param obs_list: observations. A list or tuple of observations, one per environment. Each environment observation may be a NumPy array, or a dict or tuple of NumPy arrays. - :return: flattened observations. - A flattened NumPy array or an OrderedDict or tuple of flattened numpy arrays. + :return: Stacked observations. + A NumPy array or a dict or tuple of stacked numpy arrays. Each NumPy array has the environment index as its first axis. 
""" - assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment" - assert len(obs) > 0, "need observations from at least one environment" + assert isinstance(obs_list, (list, tuple)), "expected list or tuple of observations per environment" + assert len(obs_list) > 0, "need observations from at least one environment" if isinstance(space, spaces.Dict): - assert isinstance(space.spaces, OrderedDict), "Dict space must have ordered subspaces" - assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space" - return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()]) + assert isinstance(space.spaces, dict), "Dict space must have ordered subspaces" + assert isinstance(obs_list[0], dict), "non-dict observation for environment with Dict observation space" + return {key: np.stack([single_obs[key] for single_obs in obs_list]) for key in space.spaces.keys()} # type: ignore[call-overload] elif isinstance(space, spaces.Tuple): - assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space" + assert isinstance(obs_list[0], tuple), "non-tuple observation for environment with Tuple observation space" obs_len = len(space.spaces) - return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len)) # type: ignore[index] + return tuple(np.stack([single_obs[i] for single_obs in obs_list]) for i in range(obs_len)) # type: ignore[index] else: - return np.stack(obs) # type: ignore[arg-type] + return np.stack(obs_list) # type: ignore[arg-type] diff --git a/stable_baselines3/common/vec_env/util.py b/stable_baselines3/common/vec_env/util.py index 855f50edc..6ea04f6ab 100644 --- a/stable_baselines3/common/vec_env/util.py +++ b/stable_baselines3/common/vec_env/util.py @@ -2,7 +2,6 @@ Helpers for dealing with vectorized environments. """ -from collections import OrderedDict from typing import Any, Dict, List, Tuple import numpy as np @@ -12,17 +11,6 @@ from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs -def copy_obs_dict(obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: - """ - Deep-copy a dict of numpy arrays. - - :param obs: a dict of numpy arrays. - :return: a dict of copied numpy arrays. 
- """ - assert isinstance(obs, OrderedDict), f"unexpected type for observations '{type(obs)}'" - return OrderedDict([(k, np.copy(v)) for k, v in obs.items()]) - - def dict_to_obs(obs_space: spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs: """ Convert an internal representation raw_obs into the appropriate type @@ -60,13 +48,13 @@ def obs_space_info(obs_space: spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[ """ check_for_nested_spaces(obs_space) if isinstance(obs_space, spaces.Dict): - assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces" + assert isinstance(obs_space.spaces, dict), "Dict space must have ordered subspaces" subspaces = obs_space.spaces elif isinstance(obs_space, spaces.Tuple): - subspaces = {i: space for i, space in enumerate(obs_space.spaces)} # type: ignore[assignment] + subspaces = {i: space for i, space in enumerate(obs_space.spaces)} # type: ignore[assignment,misc] else: assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'" - subspaces = {None: obs_space} # type: ignore[assignment] + subspaces = {None: obs_space} # type: ignore[assignment,dict-item] keys = [] shapes = {} dtypes = {} @@ -74,4 +62,4 @@ def obs_space_info(obs_space: spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[ keys.append(key) shapes[key] = box.shape dtypes[key] = box.dtype - return keys, shapes, dtypes + return keys, shapes, dtypes # type: ignore[return-value] diff --git a/stable_baselines3/common/vec_env/vec_video_recorder.py b/stable_baselines3/common/vec_env/vec_video_recorder.py index 52faebd1f..e586f94ab 100644 --- a/stable_baselines3/common/vec_env/vec_video_recorder.py +++ b/stable_baselines3/common/vec_env/vec_video_recorder.py @@ -1,7 +1,9 @@ import os -from typing import Callable +import os.path +from typing import Callable, List -from gymnasium.wrappers.monitoring import video_recorder +import numpy as np +from gymnasium import error, logger from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv @@ -13,6 +15,11 @@ class VecVideoRecorder(VecEnvWrapper): Wraps a VecEnv or VecEnvWrapper object to record rendered image as mp4 video. It requires ffmpeg or avconv to be installed on the machine. + Note: for now it only allows to record one video and all videos + must have at least two frames. + + The video recorder code was adapted from Gymnasium v1.0. + :param venv: :param video_folder: Where to save videos :param record_video_trigger: Function that defines when to start recording. 
@@ -22,8 +29,6 @@ class VecVideoRecorder(VecEnvWrapper): :param name_prefix: Prefix to the video name """ - video_recorder: video_recorder.VideoRecorder - def __init__( self, venv: VecEnv, @@ -51,6 +56,8 @@ def __init__( self.env.metadata = metadata assert self.env.render_mode == "rgb_array", f"The render_mode must be 'rgb_array', not {self.env.render_mode}" + self.frames_per_sec = self.env.metadata.get("render_fps", 30) + self.record_video_trigger = record_video_trigger self.video_folder = os.path.abspath(video_folder) # Create output folder if needed @@ -60,54 +67,88 @@ def __init__( self.step_id = 0 self.video_length = video_length + self.video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}.mp4" + self.video_path = os.path.join(self.video_folder, self.video_name) + self.recording = False - self.recorded_frames = 0 + self.recorded_frames: list[np.ndarray] = [] + + try: + import moviepy # noqa: F401 + except ImportError as e: + raise error.DependencyNotInstalled("MoviePy is not installed, run `pip install 'gymnasium[other]'`") from e def reset(self) -> VecEnvObs: obs = self.venv.reset() - self.start_video_recorder() + if self._video_enabled(): + self._start_video_recorder() return obs - def start_video_recorder(self) -> None: - self.close_video_recorder() - - video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}" - base_path = os.path.join(self.video_folder, video_name) - self.video_recorder = video_recorder.VideoRecorder( - env=self.env, base_path=base_path, metadata={"step_id": self.step_id} - ) - - self.video_recorder.capture_frame() - self.recorded_frames = 1 - self.recording = True + def _start_video_recorder(self) -> None: + self._start_recording() + self._capture_frame() def _video_enabled(self) -> bool: return self.record_video_trigger(self.step_id) def step_wait(self) -> VecEnvStepReturn: - obs, rews, dones, infos = self.venv.step_wait() + obs, rewards, dones, infos = self.venv.step_wait() self.step_id += 1 if self.recording: - self.video_recorder.capture_frame() - self.recorded_frames += 1 - if self.recorded_frames > self.video_length: - print(f"Saving video to {self.video_recorder.path}") - self.close_video_recorder() + self._capture_frame() + if len(self.recorded_frames) > self.video_length: + print(f"Saving video to {self.video_path}") + self._stop_recording() elif self._video_enabled(): - self.start_video_recorder() + self._start_video_recorder() - return obs, rews, dones, infos + return obs, rewards, dones, infos - def close_video_recorder(self) -> None: - if self.recording: - self.video_recorder.close() - self.recording = False - self.recorded_frames = 1 + def _capture_frame(self) -> None: + assert self.recording, "Cannot capture a frame, recording wasn't started." + + frame = self.env.render() + if isinstance(frame, List): + frame = frame[-1] + + if isinstance(frame, np.ndarray): + self.recorded_frames.append(frame) + else: + self._stop_recording() + logger.warn( + f"Recording stopped: expected type of frame returned by render to be a numpy array, got instead {type(frame)}." + ) def close(self) -> None: + """Closes the wrapper then the video recorder.""" VecEnvWrapper.close(self) - self.close_video_recorder() + if self.recording: + self._stop_recording() + + def _start_recording(self) -> None: + """Start a new recording. 
If it is already recording, stops the current recording before starting the new one.""" + if self.recording: + self._stop_recording() + + self.recording = True + + def _stop_recording(self) -> None: + """Stop current recording and saves the video.""" + assert self.recording, "_stop_recording was called, but no recording was started" + + if len(self.recorded_frames) == 0: + logger.warn("Ignored saving a video as there were zero frames to save.") + else: + from moviepy.video.io.ImageSequenceClip import ImageSequenceClip + + clip = ImageSequenceClip(self.recorded_frames, fps=self.frames_per_sec) + clip.write_videofile(self.video_path) + + self.recorded_frames = [] + self.recording = False - def __del__(self): - self.close_video_recorder() + def __del__(self) -> None: + """Warn the user in case last video wasn't saved.""" + if len(self.recorded_frames) > 0: + logger.warn("Unable to save last video! Did you call close()?") diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 852a32b3f..d5cafdb5a 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.4.0a10 +2.4.0a11 diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index f093e47e7..8049c6887 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -117,12 +117,11 @@ def test_consistency(model_class): """ use_discrete_actions = model_class == DQN dict_env = DummyDictEnv(use_discrete_actions=use_discrete_actions, vec_only=True) + dict_env.seed(10) dict_env = gym.wrappers.TimeLimit(dict_env, 100) env = gym.wrappers.FlattenObservation(dict_env) - dict_env.seed(10) obs, _ = dict_env.reset() - kwargs = {} n_steps = 256 if model_class in {A2C, PPO}: diff --git a/tests/test_gae.py b/tests/test_gae.py index 83b95a4c0..bb674cffa 100644 --- a/tests/test_gae.py +++ b/tests/test_gae.py @@ -73,7 +73,7 @@ def _on_rollout_end(self): buffer = self.model.rollout_buffer rollout_size = buffer.size() - max_steps = self.training_env.envs[0].max_steps + max_steps = self.training_env.envs[0].get_wrapper_attr("max_steps") gamma = self.model.gamma gae_lambda = self.model.gae_lambda value = self.model.policy.constant_value diff --git a/tests/test_logger.py b/tests/test_logger.py index bc18bf2ce..02d36b306 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -592,6 +592,7 @@ def test_rollout_success_rate_onpolicy_algo(tmp_path): """ STATS_WINDOW_SIZE = 10 + # Add dummy successes with 0.3, 0.5 and 0.8 success_rate of length STATS_WINDOW_SIZE dummy_successes = [ [True] * 3 + [False] * 7, @@ -603,16 +604,17 @@ def test_rollout_success_rate_onpolicy_algo(tmp_path): # Monitor the env to track the success info monitor_file = str(tmp_path / "monitor.csv") env = Monitor(DummySuccessEnv(dummy_successes, ep_steps), filename=monitor_file, info_keywords=("is_success",)) + steps_per_log = env.unwrapped.steps_per_log # Equip the model of a custom logger to check the success_rate info - model = PPO("MlpPolicy", env=env, stats_window_size=STATS_WINDOW_SIZE, n_steps=env.steps_per_log, verbose=1) + model = PPO("MlpPolicy", env=env, stats_window_size=STATS_WINDOW_SIZE, n_steps=steps_per_log, verbose=1) logger = InMemoryLogger() model.set_logger(logger) # Make the model learn and check that the success rate corresponds to the ratio of dummy successes - model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1) + model.learn(total_timesteps=steps_per_log * ep_steps, log_interval=1) assert logger.name_to_value["rollout/success_rate"] == 0.3 - 
model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1) + model.learn(total_timesteps=steps_per_log * ep_steps, log_interval=1) assert logger.name_to_value["rollout/success_rate"] == 0.5 - model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1) + model.learn(total_timesteps=steps_per_log * ep_steps, log_interval=1) assert logger.name_to_value["rollout/success_rate"] == 0.8 diff --git a/tests/test_utils.py b/tests/test_utils.py index 81f134168..bb2ebd067 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,7 @@ import os import shutil +import ale_py import gymnasium as gym import numpy as np import pytest @@ -24,6 +25,8 @@ ) from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv +gym.register_envs(ale_py) + @pytest.mark.parametrize("env_id", ["CartPole-v1", lambda: gym.make("CartPole-v1")]) @pytest.mark.parametrize("n_envs", [1, 2]) diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index a9516ae25..3aa52762d 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -307,7 +307,7 @@ def test_vecenv_dict_spaces(vec_env_class): space = spaces.Dict(SPACES) def obs_assert(obs): - assert isinstance(obs, collections.OrderedDict) + assert isinstance(obs, dict) assert obs.keys() == space.spaces.keys() for key, values in obs.items(): check_vecenv_obs(values, space.spaces[key])
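A few of the changes in this patch are easier to see with small, self-contained examples. First, the recurring `getattr` → `Env.get_wrapper_attr` substitution (in `DummyVecEnv`, the `SubprocVecEnv` worker, and `tests/test_gae.py`) is needed because Gymnasium v1.0 wrappers no longer forward arbitrary attributes implicitly. A sketch of the difference, assuming Gymnasium >= 0.29; the env and attribute are illustrative choices, not from the patch:

```python
import gymnasium as gym

# gym.make returns CartPoleEnv wrapped in TimeLimit (among other wrappers).
env = gym.make("CartPole-v1")

# `gravity` lives on the underlying CartPoleEnv, not on the wrappers.
# get_wrapper_attr walks the wrapper stack until it finds the attribute,
# which plain getattr on the outermost wrapper no longer does in v1.0.
gravity = env.get_wrapper_attr("gravity")
assert gravity == 9.8
```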
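Second, the `_flatten_obs` → `_stack_obs` rename and the removal of `OrderedDict` do not change observation contents: for a `Dict` space, a plain dict of stacked arrays comes back, which is exactly what the relaxed assertion in `tests/test_vec_envs.py` checks. A toy illustration (my own example, not patch code):

```python
import numpy as np

# Two single-env observations, as SubprocVecEnv workers would return them.
obs_list = [
    {"pos": np.array([0.0, 1.0]), "vel": np.array([0.5])},
    {"pos": np.array([2.0, 3.0]), "vel": np.array([1.5])},
]

# Same dict comprehension as the patched _stack_obs: no OrderedDict needed,
# since plain dicts preserve insertion order on all supported Pythons.
stacked = {key: np.stack([obs[key] for obs in obs_list]) for key in obs_list[0]}

assert isinstance(stacked, dict)
assert stacked["pos"].shape == (2, 2)  # environment index is the first axis
```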
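Third, the moviepy-based `VecVideoRecorder` keeps the old constructor signature, so user code should run unchanged. A usage sketch, assuming `moviepy` is installed as the new dependency check requires; the trigger and length values are illustrative:

```python
import gymnasium as gym

from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder

# render_mode="rgb_array" is required by the recorder's assertion.
venv = DummyVecEnv([lambda: gym.make("CartPole-v1", render_mode="rgb_array")])
venv = VecVideoRecorder(
    venv,
    video_folder="videos/",
    record_video_trigger=lambda step: step == 0,  # record from the first step
    video_length=100,
)

venv.reset()
for _ in range(120):  # stepping past video_length triggers the mp4 save
    venv.step([venv.action_space.sample()])
venv.close()
```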
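Finally, the `setup.py` switch from AutoROM to `ale-py>=0.9.0` works because recent `ale-py` wheels bundle the Atari ROMs; under Gymnasium v1.0 the ALE envs only need explicit registration, as `tests/test_utils.py` now does. A sketch, with an example env id:

```python
import ale_py
import gymnasium as gym

# Gymnasium v1.0 dropped the plugin system that auto-registered ALE envs,
# so they must be registered explicitly from the ale_py module.
gym.register_envs(ale_py)

env = gym.make("ALE/Breakout-v5")
```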