From a653aec10d8cb4b82443f818fdde123f06c95f75 Mon Sep 17 00:00:00 2001 From: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> Date: Wed, 10 Jan 2024 14:46:40 +0100 Subject: [PATCH 01/30] Docs: Env attributes should be modified using env setters (#1789) * add: paragraph on how to modify vec envs attributes via setters (solves DLR-RM#1573) * Update vec env doc * Update callback doc and SB3 version * Fix indentation --------- Co-authored-by: Antonin Raffin --- docs/guide/callbacks.rst | 15 ++++--- docs/guide/vec_envs.rst | 84 +++++++++++++++++++++++++++++++++++ docs/misc/changelog.rst | 35 ++++++++++++++- stable_baselines3/version.txt | 2 +- 4 files changed, 127 insertions(+), 9 deletions(-) diff --git a/docs/guide/callbacks.rst b/docs/guide/callbacks.rst index 239966a6f..472f42114 100644 --- a/docs/guide/callbacks.rst +++ b/docs/guide/callbacks.rst @@ -29,24 +29,25 @@ You can find two examples of custom callbacks in the documentation: one for savi :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages """ - def __init__(self, verbose=0): + def __init__(self, verbose: int = 0): super().__init__(verbose) # Those variables will be accessible in the callback # (they are defined in the base class) # The RL model # self.model = None # type: BaseAlgorithm # An alias for self.model.get_env(), the environment used for training - # self.training_env = None # type: Union[gym.Env, VecEnv, None] + # self.training_env # type: VecEnv # Number of time the callback was called # self.n_calls = 0 # type: int + # num_timesteps = n_envs * n times env.step() was called # self.num_timesteps = 0 # type: int # local and global variables - # self.locals = None # type: Dict[str, Any] - # self.globals = None # type: Dict[str, Any] + # self.locals = {} # type: Dict[str, Any] + # self.globals = {} # type: Dict[str, Any] # The logger object, used to report things in the terminal - # self.logger = None # stable_baselines3.common.logger - # # Sometimes, for event callback, it is useful - # # to have access to the parent object + # self.logger # type: stable_baselines3.common.logger.Logger + # Sometimes, for event callback, it is useful + # to have access to the parent object # self.parent = None # type: Optional[BaseCallback] def _on_training_start(self) -> None: diff --git a/docs/guide/vec_envs.rst b/docs/guide/vec_envs.rst index 792fedecb..c04001c7c 100644 --- a/docs/guide/vec_envs.rst +++ b/docs/guide/vec_envs.rst @@ -96,6 +96,90 @@ SB3 VecEnv API is actually close to Gym 0.21 API but differs to Gym 0.26+ API: ``vec_env.env_method("method_name", args1, args2, kwargs1=kwargs1)`` and ``vec_env.set_attr("attribute_name", new_value)``. +Modifying Vectorized Environments Attributes +-------------------------------------------- + +If you plan to `modify the attributes of an environment `_ while it is used (e.g., modifying an attribute specifying the task carried out for a portion of training when doing multi-task learning, or +a parameter of the environment dynamics), you must expose a setter method. +In fact, directly accessing the environment attribute in the callback can lead to unexpected behavior because environments can be wrapped (using gym or VecEnv wrappers, the ``Monitor`` wrapper being one example). + +Consider the following example for a custom env: + +.. 
code-block:: python

    import gymnasium as gym
    from gymnasium import spaces

    from stable_baselines3.common.env_util import make_vec_env


    class MyMultiTaskEnv(gym.Env):

        def __init__(self):
            super().__init__()
            """
            A state and action space for robotic locomotion.
            The multi-task twist is that the policy would need to adapt to different terrains, each with its own
            friction coefficient, mu.
            The friction coefficient is the only parameter that changes between tasks.
            mu is a scalar between 0 and 1, and during training a callback is used to update mu.
            """
            ...

        def step(self, action):
            # Do something: depending on the action and the current value of mu, the next state is computed
            return self._get_obs(), reward, done, truncated, info

        def set_mu(self, new_mu: float) -> None:
            # Note: this value should be used only at the next reset
            self.mu = new_mu

    # Example of wrapped env
    # env is of type <TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>>
    env = gym.make("CartPole-v1")
    # To access the base env, without wrapper, you should use `.unwrapped`
    # or env.get_wrapper_attr("gravity") to include wrappers
    env.unwrapped.gravity
    # SB3 uses VecEnv for training, where `env.unwrapped.x = new_value` cannot be used to set an attribute
    # therefore, you should expose a setter like `set_mu` to properly set an attribute
    vec_env = make_vec_env(MyMultiTaskEnv)
    # Print current mu value
    # Note: you should use vec_env.env_method("get_wrapper_attr", "mu") in Gymnasium v1.0
    print(vec_env.env_method("get_wrapper_attr", "mu"))
    # Change `mu` attribute via the setter (set_mu takes a single argument, the new value)
    vec_env.env_method("set_mu", 0.1)


In this example ``env.mu`` cannot be accessed/changed directly because it is wrapped in a ``VecEnv`` and because it could be wrapped with other wrappers (see `GH#1573 `_ for a longer explanation).
Instead, the callback should use the ``set_mu`` method via the ``env_method`` method for Vectorized Environments.

.. code-block:: python

    from itertools import cycle

    from stable_baselines3.common.callbacks import BaseCallback


    class ChangeMuCallback(BaseCallback):
        """
        This callback changes the value of mu during training, looping
        through a list of values until training is aborted.
        The environment is implemented so that the impact of changing
        the value of mu mid-episode is visible only after the episode is over
        and the reset method has been called.
        """
        def __init__(self):
            super().__init__()
            # An iterator that cycles through the different values of the friction coefficient
            self.mus = cycle([0.1, 0.2, 0.5, 0.13, 0.9])

        def _on_step(self):
            # Note: in practice, you should not change this value at every step
            # but rather depending on some events/metrics like agent performance/episode termination
            # both accessible via the `self.logger` or `self.locals` variables
            self.training_env.env_method("set_mu", next(self.mus))

This callback can then be used to safely modify environment attributes during training since
it calls the environment setter method.
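
For completeness, here is a minimal sketch of how such a callback can be attached to a model.
It assumes that ``MyMultiTaskEnv`` defines its observation and action spaces (elided with ``...`` above);
the choice of algorithm (PPO), number of environments and number of timesteps is purely illustrative.

.. code-block:: python

    from stable_baselines3 import PPO
    from stable_baselines3.common.env_util import make_vec_env

    vec_env = make_vec_env(MyMultiTaskEnv, n_envs=4)
    model = PPO("MlpPolicy", vec_env, verbose=1)
    # The callback updates mu through the `set_mu` setter exposed by the environment,
    # the new value is taken into account at the next episode reset
    model.learn(total_timesteps=10_000, callback=ChangeMuCallback())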
+ + Vectorized Environments Wrappers -------------------------------- diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 647a0e89e..cbfe41f9d 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,6 +3,39 @@ Changelog ========== + +Release 2.3.0a0 (WIP) +-------------------------- + +Breaking Changes: +^^^^^^^^^^^^^^^^^ + +New Features: +^^^^^^^^^^^^^ + +Bug Fixes: +^^^^^^^^^^ + +`SB3-Contrib`_ +^^^^^^^^^^^^^^ + +`RL Zoo`_ +^^^^^^^^^ + +`SBX`_ (SB3 + Jax) +^^^^^^^^^^^^^^^^^^ + +Deprecations: +^^^^^^^^^^^^^ + +Others: +^^^^^^^ + +Documentation: +^^^^^^^^^^^^^^ +- Added a paragraph on modifying vectorized environment parameters via setters (@fracapuano) +- Updated callback code example + Release 2.2.1 (2023-11-17) -------------------------- **Support for options at reset, bug fixes and better error messages** @@ -1490,7 +1523,7 @@ And all the contributors: @flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3 @tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio @diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray -@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @JadenTravnik @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn +@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @fracapuano @JadenTravnik @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn @ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida @09tangriro @amy12xx @juancroldan @benblack769 @bstee615 @c-rizz @skandermoalla @MihaiAnca13 @davidblom603 @ayeright @cyprienc @wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index c043eea77..00b35529e 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.2.1 +2.3.0a0 From a9273f968eaf8c6e04302a07d803eebfca6e7e86 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 12 Jan 2024 16:05:14 +0100 Subject: [PATCH 02/30] Update TD3/DDPG/DQN defaults for consistency (#1785) * Update TD3/DDPG/DQN defaults for consistency * Update changelog --- docs/misc/changelog.rst | 28 ++++++++++++++++++++++++++-- stable_baselines3/ddpg/ddpg.py | 6 +++--- stable_baselines3/dqn/dqn.py | 2 +- stable_baselines3/td3/td3.py | 6 +++--- stable_baselines3/version.txt | 2 +- 5 files changed, 34 insertions(+), 10 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index cbfe41f9d..a4d8e6373 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,12 +3,36 @@ Changelog ========== - -Release 2.3.0a0 (WIP) +Release 2.3.0a1 (WIP) -------------------------- Breaking Changes: ^^^^^^^^^^^^^^^^^ +- The defaults hyperparameters of ``TD3`` and ``DDPG`` have been changed to be more consistent with ``SAC`` + +.. code-block:: python + + # SB3 < 2.3.0 default hyperparameters + # model = TD3("MlpPolicy", env, train_freq=(1, "episode"), gradient_steps=-1, batch_size=100) + # SB3 >= 2.3.0: + model = TD3("MlpPolicy", env, train_freq=1, gradient_steps=1, batch_size=256) + +.. 
note:: + + Two inconsistencies remains: the default network architecture for ``TD3/DDPG`` is ``[400, 300]`` instead of ``[256, 256]`` for SAC (for backward compatibility reasons, see `report on the influence of the network size `_) and the default learning rate is 1e-3 instead of 3e-4 for SAC (for performance reasons, see `W&B report on the influence of the lr `_) + + + +- The default ``leanrning_starts`` parameter of ``DQN`` have been changed to be consistent with the other offpolicy algorithms + + +.. code-block:: python + + # SB3 < 2.3.0 default hyperparameters, 50_000 corresponded to Atari defaults hyperparameters + # model = DQN("MlpPolicy", env, learning_start=50_000) + # SB3 >= 2.3.0: + model = DQN("MlpPolicy", env, learning_start=100) + New Features: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/ddpg/ddpg.py b/stable_baselines3/ddpg/ddpg.py index c311b2357..2fe2fdfc4 100644 --- a/stable_baselines3/ddpg/ddpg.py +++ b/stable_baselines3/ddpg/ddpg.py @@ -60,11 +60,11 @@ def __init__( learning_rate: Union[float, Schedule] = 1e-3, buffer_size: int = 1_000_000, # 1e6 learning_starts: int = 100, - batch_size: int = 100, + batch_size: int = 256, tau: float = 0.005, gamma: float = 0.99, - train_freq: Union[int, Tuple[int, str]] = (1, "episode"), - gradient_steps: int = -1, + train_freq: Union[int, Tuple[int, str]] = 1, + gradient_steps: int = 1, action_noise: Optional[ActionNoise] = None, replay_buffer_class: Optional[Type[ReplayBuffer]] = None, replay_buffer_kwargs: Optional[Dict[str, Any]] = None, diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index 42e3d0df0..894ed9f04 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -79,7 +79,7 @@ def __init__( env: Union[GymEnv, str], learning_rate: Union[float, Schedule] = 1e-4, buffer_size: int = 1_000_000, # 1e6 - learning_starts: int = 50000, + learning_starts: int = 100, batch_size: int = 32, tau: float = 1.0, gamma: float = 0.99, diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index a06ce67e0..a61d954bc 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -83,11 +83,11 @@ def __init__( learning_rate: Union[float, Schedule] = 1e-3, buffer_size: int = 1_000_000, # 1e6 learning_starts: int = 100, - batch_size: int = 100, + batch_size: int = 256, tau: float = 0.005, gamma: float = 0.99, - train_freq: Union[int, Tuple[int, str]] = (1, "episode"), - gradient_steps: int = -1, + train_freq: Union[int, Tuple[int, str]] = 1, + gradient_steps: int = 1, action_noise: Optional[ActionNoise] = None, replay_buffer_class: Optional[Type[ReplayBuffer]] = None, replay_buffer_kwargs: Optional[Dict[str, Any]] = None, diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 00b35529e..4d04ad95c 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.3.0a0 +2.3.0a1 From 620e58e61f649d0f415b7796386d6fe405778026 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 30 Jan 2024 15:53:25 +0100 Subject: [PATCH 03/30] Update SB3 ONNX export documentation (#1816) --- docs/guide/export.rst | 85 +++++++++++++++++++++++------------------ docs/guide/rl_tips.rst | 8 +++- docs/misc/changelog.rst | 2 + 3 files changed, 57 insertions(+), 38 deletions(-) diff --git a/docs/guide/export.rst b/docs/guide/export.rst index cccf30014..88a02fe8b 100644 --- a/docs/guide/export.rst +++ b/docs/guide/export.rst @@ -31,53 +31,52 @@ to do inference in another framework. 
Export to ONNX ----------------- -As of June 2021, ONNX format `doesn't support `_ exporting models that use the ``broadcast_tensors`` functionality of pytorch. So in order to export the trained stable-baseline3 models in the ONNX format, we need to first remove the layers that use broadcasting. This can be done by creating a class that removes the unsupported layers. -The following examples are for ``MlpPolicy`` only, and are general examples. Note that you have to preprocess the observation the same way stable-baselines3 agent does (see ``common.preprocessing.preprocess_obs``). +If you are using PyTorch 2.0+ and ONNX Opset 14+, you can easily export SB3 policies using the following code: -For PPO, assuming a shared feature extractor. .. warning:: - The following example is for continuous actions only. - When using discrete or binary actions, you must do some `post-processing `_ - to obtain the action (e.g., convert action logits to action). + The following returns normalized actions and doesn't include the `post-processing `_ step that is done with continuous actions + (clip or unscale the action to the correct space). .. code-block:: python import torch as th + from typing import Tuple from stable_baselines3 import PPO + from stable_baselines3.common.policies import BasePolicy - class OnnxablePolicy(th.nn.Module): - def __init__(self, extractor, action_net, value_net): + class OnnxableSB3Policy(th.nn.Module): + def __init__(self, policy: BasePolicy): super().__init__() - self.extractor = extractor - self.action_net = action_net - self.value_net = value_net + self.policy = policy - def forward(self, observation): - # NOTE: You may have to process (normalize) observation in the correct - # way before using this. See `common.preprocessing.preprocess_obs` - action_hidden, value_hidden = self.extractor(observation) - return self.action_net(action_hidden), self.value_net(value_hidden) + def forward(self, observation: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + # NOTE: Preprocessing is included, but postprocessing + # (clipping/inscaling actions) is not, + # If needed, you also need to transpose the images so that they are channel first + # use deterministic=False if you want to export the stochastic policy + # policy() returns `actions, values, log_prob` for PPO + return self.policy(observation, deterministic=True) # Example: model = PPO("MlpPolicy", "Pendulum-v1") + PPO("MlpPolicy", "Pendulum-v1").save("PathToTrainedModel") model = PPO.load("PathToTrainedModel.zip", device="cpu") - onnxable_model = OnnxablePolicy( - model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net - ) + + onnx_policy = OnnxableSB3Policy(model.policy) observation_size = model.observation_space.shape dummy_input = th.randn(1, *observation_size) th.onnx.export( - onnxable_model, + onnx_policy, dummy_input, "my_ppo_model.onnx", - opset_version=9, + opset_version=17, input_names=["input"], ) @@ -93,7 +92,13 @@ For PPO, assuming a shared feature extractor. observation = np.zeros((1, *observation_size)).astype(np.float32) ort_sess = ort.InferenceSession(onnx_path) - action, value = ort_sess.run(None, {"input": observation}) + actions, values, log_prob = ort_sess.run(None, {"input": observation}) + + print(actions, values, log_prob) + + # Check that the predictions are the same + with th.no_grad(): + print(model.policy(th.as_tensor(observation), deterministic=True)) For SAC the procedure is similar. 
The example shown only exports the actor network as the actor is sufficient to roll out the trained policies. @@ -108,23 +113,16 @@ For SAC the procedure is similar. The example shown only exports the actor netwo class OnnxablePolicy(th.nn.Module): def __init__(self, actor: th.nn.Module): super().__init__() - # Removing the flatten layer because it can't be onnxed - self.actor = th.nn.Sequential( - actor.latent_pi, - actor.mu, - # For gSDE - # th.nn.Hardtanh(min_val=-actor.clip_mean, max_val=actor.clip_mean), - # Squash the output - th.nn.Tanh(), - ) + self.actor = actor def forward(self, observation: th.Tensor) -> th.Tensor: - # NOTE: You may have to process (normalize) observation in the correct - # way before using this. See `common.preprocessing.preprocess_obs` - return self.actor(observation) + # NOTE: You may have to postprocess (unnormalize) actions + # to the correct bounds (see commented code below) + return self.actor(observation, deterministic=True) # Example: model = SAC("MlpPolicy", "Pendulum-v1") + SAC("MlpPolicy", "Pendulum-v1").save("PathToTrainedModel.zip") model = SAC.load("PathToTrainedModel.zip", device="cpu") onnxable_model = OnnxablePolicy(model.policy.actor) @@ -134,7 +132,7 @@ For SAC the procedure is similar. The example shown only exports the actor netwo onnxable_model, dummy_input, "my_sac_actor.onnx", - opset_version=9, + opset_version=17, input_names=["input"], ) @@ -147,10 +145,23 @@ For SAC the procedure is similar. The example shown only exports the actor netwo observation = np.zeros((1, *observation_size)).astype(np.float32) ort_sess = ort.InferenceSession(onnx_path) - action = ort_sess.run(None, {"input": observation}) + scaled_action = ort_sess.run(None, {"input": observation})[0] + + print(scaled_action) + + # Post-process: rescale to correct space + # Rescale the action from [-1, 1] to [low, high] + # low, high = model.action_space.low, model.action_space.high + # post_processed_action = low + (0.5 * (scaled_action + 1.0) * (high - low)) + + # Check that the predictions are the same + with th.no_grad(): + print(model.actor(th.as_tensor(observation), deterministic=True)) + + +For more discussion around the topic, please refer to `GH#383 `_ and `GH#1349 `_. -For more discussion around the topic refer to this `issue. `_ Trace/Export to C++ ------------------- diff --git a/docs/guide/rl_tips.rst b/docs/guide/rl_tips.rst index ce6f43e55..ae37640c7 100644 --- a/docs/guide/rl_tips.rst +++ b/docs/guide/rl_tips.rst @@ -252,6 +252,12 @@ A better solution would be to use a squashing function (cf ``SAC``) or a Beta di Tips and Tricks when implementing an RL algorithm ================================================= +.. note:: + + We have a `video on YouTube about reliable RL `_ that covers + this section in more details. You can also find the `slides online `_. + + When you try to reproduce a RL paper by implementing the algorithm, the `nuts and bolts of RL research `_ by John Schulman are quite useful (`video `_). @@ -282,4 +288,4 @@ in RL with discrete actions: 3. Pong (one of the easiest Atari game) 4. other Atari games (e.g. Breakout) -.. _SBX: https://github.com/araffin/sbx \ No newline at end of file +.. 
_SBX: https://github.com/araffin/sbx diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index a4d8e6373..4f22c672b 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -59,6 +59,8 @@ Documentation: ^^^^^^^^^^^^^^ - Added a paragraph on modifying vectorized environment parameters via setters (@fracapuano) - Updated callback code example +- Updated export to ONNX documentation, it is now much simpler to export SB3 models with newer ONNX Opset! +- Added video link to "Practical Tips for Reliable Reinforcement Learning" video Release 2.2.1 (2023-11-17) -------------------------- From beee4279eb465a9eb68cebf5664f3ba0e70088fa Mon Sep 17 00:00:00 2001 From: Marek Michalik Date: Tue, 13 Feb 2024 10:47:05 +0100 Subject: [PATCH 04/30] Fix example in README.md (#1830) * Fix example in README.md * Update changelog --------- Co-authored-by: Antonin Raffin --- README.md | 2 +- docs/misc/changelog.rst | 2 ++ setup.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 4f427087b..6e55f1030 100644 --- a/README.md +++ b/README.md @@ -127,7 +127,7 @@ import gymnasium as gym from stable_baselines3 import PPO -env = gym.make("CartPole-v1") +env = gym.make("CartPole-v1", render_mode="human") model = PPO("MlpPolicy", env, verbose=1) model.learn(total_timesteps=10_000) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 4f22c672b..006f156ef 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -61,6 +61,7 @@ Documentation: - Updated callback code example - Updated export to ONNX documentation, it is now much simpler to export SB3 models with newer ONNX Opset! - Added video link to "Practical Tips for Reliable Reinforcement Learning" video +- Added ``render_mode="human"`` in the README example (@marekm4) Release 2.2.1 (2023-11-17) -------------------------- @@ -1561,3 +1562,4 @@ And all the contributors: @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @ReHoss @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto @lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger +@marekm4 diff --git a/setup.py b/setup.py index 5e10ed66c..817fae22a 100644 --- a/setup.py +++ b/setup.py @@ -43,7 +43,7 @@ from stable_baselines3 import PPO -env = gymnasium.make("CartPole-v1") +env = gymnasium.make("CartPole-v1", render_mode="human") model = PPO("MlpPolicy", env, verbose=1) model.learn(total_timesteps=10_000) From 1cba1bbd2f129f3e3140d6a1e478dd4b3979a2bf Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 13 Feb 2024 11:36:05 +0100 Subject: [PATCH 05/30] Update to black style v24 (#1834) --- docs/misc/changelog.rst | 3 ++- setup.py | 2 +- stable_baselines3/common/save_util.py | 1 + stable_baselines3/common/type_aliases.py | 1 + stable_baselines3/common/vec_env/util.py | 1 + stable_baselines3/common/vec_env/vec_frame_stack.py | 7 ++++++- stable_baselines3/version.txt | 2 +- 7 files changed, 13 insertions(+), 4 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 006f156ef..cf101af45 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.3.0a1 (WIP) +Release 2.3.0a2 (WIP) -------------------------- Breaking Changes: @@ -54,6 +54,7 @@ Deprecations: Others: ^^^^^^^ +- Updated black from v23 to v24 Documentation: ^^^^^^^^^^^^^^ 
diff --git a/setup.py b/setup.py index 817fae22a..763a6a376 100644 --- a/setup.py +++ b/setup.py @@ -122,7 +122,7 @@ # Lint code and sort imports (flake8 and isort replacement) "ruff>=0.0.288", # Reformat - "black>=23.9.1,<24", + "black>=24.2.0,<25", ], "docs": [ "sphinx>=5,<8", diff --git a/stable_baselines3/common/save_util.py b/stable_baselines3/common/save_util.py index 0cbf6d4e2..9fca6a832 100644 --- a/stable_baselines3/common/save_util.py +++ b/stable_baselines3/common/save_util.py @@ -2,6 +2,7 @@ Save util taken from stable_baselines used to serialize data (class parameters) of model classes """ + import base64 import functools import io diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index d75e11531..85d09066e 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -1,4 +1,5 @@ """Common aliases for type hints""" + from enum import Enum from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Protocol, SupportsFloat, Tuple, Union diff --git a/stable_baselines3/common/vec_env/util.py b/stable_baselines3/common/vec_env/util.py index 2a03d8e70..855f50edc 100644 --- a/stable_baselines3/common/vec_env/util.py +++ b/stable_baselines3/common/vec_env/util.py @@ -1,6 +1,7 @@ """ Helpers for dealing with vectorized environments. """ + from collections import OrderedDict from typing import Any, Dict, List, Tuple diff --git a/stable_baselines3/common/vec_env/vec_frame_stack.py b/stable_baselines3/common/vec_env/vec_frame_stack.py index d412a96a2..daa2b365c 100644 --- a/stable_baselines3/common/vec_env/vec_frame_stack.py +++ b/stable_baselines3/common/vec_env/vec_frame_stack.py @@ -29,7 +29,12 @@ def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[Union[st def step_wait( self, - ) -> Tuple[Union[np.ndarray, Dict[str, np.ndarray]], np.ndarray, np.ndarray, List[Dict[str, Any]],]: + ) -> Tuple[ + Union[np.ndarray, Dict[str, np.ndarray]], + np.ndarray, + np.ndarray, + List[Dict[str, Any]], + ]: observations, rewards, dones, infos = self.venv.step_wait() observations, infos = self.stacked_obs.update(observations, dones, infos) # type: ignore[arg-type] return observations, rewards, dones, infos diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 4d04ad95c..34109b68e 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.3.0a1 +2.3.0a2 From a8e905977f3073066eb332f063f6335f355c455a Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 19 Feb 2024 16:44:02 +0100 Subject: [PATCH 06/30] Update env checker for spaces with non-zero start (#1845) * Update ruff * Update env checker for non-zero start --- Makefile | 2 +- docs/misc/changelog.rst | 4 ++- pyproject.toml | 6 ++-- setup.py | 2 +- stable_baselines3/common/env_checker.py | 44 ++++++++++++++++--------- stable_baselines3/version.txt | 2 +- tests/test_envs.py | 6 +++- 7 files changed, 43 insertions(+), 23 deletions(-) diff --git a/Makefile b/Makefile index fe9f6ae2e..51a5940c6 100644 --- a/Makefile +++ b/Makefile @@ -18,7 +18,7 @@ type: mypy lint: # stop the build if there are Python syntax errors or undefined names # see https://www.flake8rules.com/ - ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --show-source + ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full # exit-zero treats all errors as warnings. 
ruff ${LINT_PATHS} --exit-zero diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index cf101af45..feb096af9 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.3.0a2 (WIP) +Release 2.3.0a3 (WIP) -------------------------- Breaking Changes: @@ -55,6 +55,8 @@ Deprecations: Others: ^^^^^^^ - Updated black from v23 to v24 +- Updated ruff to >= v0.2.2 +- Updated env checker for (multi)discrete spaces with non-zero start. Documentation: ^^^^^^^^^^^^^^ diff --git a/pyproject.toml b/pyproject.toml index 1195687f4..ce0a14e0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,13 +3,15 @@ line-length = 127 # Assume Python 3.8 target-version = "py38" + +[tool.ruff.lint] # See https://beta.ruff.rs/docs/rules/ select = ["E", "F", "B", "UP", "C90", "RUF"] # B028: Ignore explicit stacklevel` # RUF013: Too many false positives (implicit optional) ignore = ["B028", "RUF013"] -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Default implementation in abstract methods "./stable_baselines3/common/callbacks.py"= ["B027"] "./stable_baselines3/common/noise.py"= ["B027"] @@ -17,7 +19,7 @@ ignore = ["B028", "RUF013"] "./tests/*.py"= ["RUF012", "RUF013"] -[tool.ruff.mccabe] +[tool.ruff.lint.mccabe] # Unlike Flake8, default to a complexity level of 10. max-complexity = 15 diff --git a/setup.py b/setup.py index 763a6a376..a077738ff 100644 --- a/setup.py +++ b/setup.py @@ -120,7 +120,7 @@ # Type check "mypy", # Lint code and sort imports (flake8 and isort replacement) - "ruff>=0.0.288", + "ruff>=0.2.2", # Reformat "black>=24.2.0,<25", ], diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index dc465a1d6..f24c86ec9 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -17,13 +17,37 @@ def _is_numpy_array_space(space: spaces.Space) -> bool: return not isinstance(space, (spaces.Dict, spaces.Tuple)) +def _starts_at_zero(space: Union[spaces.Discrete, spaces.MultiDiscrete]) -> bool: + """ + Return False if a (Multi)Discrete space has a non-zero start. + """ + return np.allclose(space.start, np.zeros_like(space.start)) + + +def _check_non_zero_start(space: spaces.Space, space_type: str = "observation", key: str = "") -> None: + """ + :param space: Observation or action space + :param space_type: information about whether it is an observation or action space + (for the warning message) + :param key: When the observation space comes from a Dict space, we pass the + corresponding key to have more precise warning messages. Defaults to "". + """ + if isinstance(space, (spaces.Discrete, spaces.MultiDiscrete)) and not _starts_at_zero(space): + maybe_key = f"(key='{key}')" if key else "" + warnings.warn( + f"{type(space).__name__} {space_type} space {maybe_key} with a non-zero start (start={space.start}) " + "is not supported by Stable-Baselines3. " + f"You can use a wrapper or update your {space_type} space." + ) + + def _check_image_input(observation_space: spaces.Box, key: str = "") -> None: """ Check that the input will be compatible with Stable-Baselines when the observation is apparently an image. :param observation_space: Observation space - :key: When the observation space comes from a Dict space, we pass the + :param key: When the observation space comes from a Dict space, we pass the corresponding key to have more precise warning messages. Defaults to "". 
""" if observation_space.dtype != np.uint8: @@ -63,11 +87,7 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act for key, space in observation_space.spaces.items(): if isinstance(space, spaces.Dict): nested_dict = True - if isinstance(space, spaces.Discrete) and space.start != 0: - warnings.warn( - f"Discrete observation space (key '{key}') with a non-zero start is not supported by Stable-Baselines3. " - "You can use a wrapper or update your observation space." - ) + _check_non_zero_start(space, "observation", key) if nested_dict: warnings.warn( @@ -87,11 +107,7 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act "which is supported by SB3." ) - if isinstance(observation_space, spaces.Discrete) and observation_space.start != 0: - warnings.warn( - "Discrete observation space with a non-zero start is not supported by Stable-Baselines3. " - "You can use a wrapper or update your observation space." - ) + _check_non_zero_start(observation_space, "observation") if isinstance(observation_space, spaces.Sequence): warnings.warn( @@ -100,11 +116,7 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act "Note: The checks for returned values are skipped." ) - if isinstance(action_space, spaces.Discrete) and action_space.start != 0: - warnings.warn( - "Discrete action space with a non-zero start is not supported by Stable-Baselines3. " - "You can use a wrapper or update your action space." - ) + _check_non_zero_start(action_space, "action") if not _is_numpy_array_space(action_space): warnings.warn( diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 34109b68e..5334cfa48 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.3.0a2 +2.3.0a3 diff --git a/tests/test_envs.py b/tests/test_envs.py index e82ef5768..9a61eeef0 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -123,6 +123,8 @@ def patched_step(_action): spaces.Dict({"img": spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)}), # Non zero start index spaces.Discrete(3, start=-1), + # Non zero start index (MultiDiscrete) + spaces.MultiDiscrete([4, 4], start=[1, 0]), # Non zero start index inside a Dict spaces.Dict({"obs": spaces.Discrete(3, start=1)}), ], @@ -164,6 +166,8 @@ def patched_step(_action): spaces.Box(low=np.array([-1, -1, -1]), high=np.array([1, 1, 0.99]), dtype=np.float32), # Non zero start index spaces.Discrete(3, start=-1), + # Non zero start index (MultiDiscrete) + spaces.MultiDiscrete([4, 4], start=[1, 0]), ], ) def test_non_default_action_spaces(new_action_space): @@ -179,7 +183,7 @@ def test_non_default_action_spaces(new_action_space): env.action_space = new_action_space # Discrete action space - if isinstance(new_action_space, spaces.Discrete): + if isinstance(new_action_space, (spaces.Discrete, spaces.MultiDiscrete)): with pytest.warns(UserWarning): check_env(env) return From 56f20e40a2206bbb16501a0f600e29ce1b112ef1 Mon Sep 17 00:00:00 2001 From: StagOverflow <62816062+StagOverflow@users.noreply.github.com> Date: Tue, 27 Feb 2024 08:49:42 -0500 Subject: [PATCH 07/30] Fix `sum_independent_dims` docstring to reflect output shape (#1851) Co-authored-by: Heinrick Lumini --- docs/misc/changelog.rst | 3 ++- stable_baselines3/common/distributions.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index feb096af9..5176c6971 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst 
@@ -65,6 +65,7 @@ Documentation: - Updated export to ONNX documentation, it is now much simpler to export SB3 models with newer ONNX Opset! - Added video link to "Practical Tips for Reliable Reinforcement Learning" video - Added ``render_mode="human"`` in the README example (@marekm4) +- Fixed docstring signature for sum_independent_dims (@stagoverflow) Release 2.2.1 (2023-11-17) -------------------------- @@ -1565,4 +1566,4 @@ And all the contributors: @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @ReHoss @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto @lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger -@marekm4 +@marekm4 @stagoverflow diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index 149345d83..132a35348 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -113,7 +113,7 @@ def sum_independent_dims(tensor: th.Tensor) -> th.Tensor: so we can sum components of the ``log_prob`` or the entropy. :param tensor: shape: (n_batch, n_actions) or (n_batch,) - :return: shape: (n_batch,) + :return: shape: (n_batch,) for (n_batch, n_actions) input, scalar for (n_batch,) input """ if len(tensor.shape) > 1: tensor = tensor.sum(dim=1) From f375cc393938a6b4e4dc0fb1de82b4afca37c1bd Mon Sep 17 00:00:00 2001 From: Rushit Shah <29002479+rushitnshah@users.noreply.github.com> Date: Mon, 4 Mar 2024 04:42:16 -0600 Subject: [PATCH 08/30] Fix docstring for ``log_interval`` to differentiate between on-policy/off-policy logging frequency (#1855) * Fix docstring for log_interval inside the learn method in the base class. * Updated changelog. * Update docstring --------- Co-authored-by: Antonin RAFFIN --- docs/misc/changelog.rst | 3 ++- stable_baselines3/common/base_class.py | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 5176c6971..e67e5edee 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -66,6 +66,7 @@ Documentation: - Added video link to "Practical Tips for Reliable Reinforcement Learning" video - Added ``render_mode="human"`` in the README example (@marekm4) - Fixed docstring signature for sum_independent_dims (@stagoverflow) +- Updated docstring description for ``log_interval`` in the base class (@rushitnshah). Release 2.2.1 (2023-11-17) -------------------------- @@ -1566,4 +1567,4 @@ And all the contributors: @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @ReHoss @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto @lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger -@marekm4 @stagoverflow +@marekm4 @stagoverflow @rushitnshah diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 5e8759990..e6c7d3cfc 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -523,7 +523,10 @@ def learn( :param total_timesteps: The total number of samples (env steps) to train on :param callback: callback(s) called at every step with state of the algorithm. 
- :param log_interval: The number of episodes before logging. + :param log_interval: for on-policy algos (e.g., PPO, A2C, ...) this is the number of + training iterations (i.e., log_interval * n_steps * n_envs timesteps) before logging; + for off-policy algos (e.g., TD3, SAC, ...) this is the number of episodes before + logging. :param tb_log_name: the name of the run for TensorBoard logging :param reset_num_timesteps: whether or not to reset the current timestep number (used in logging) :param progress_bar: Display a progress bar using tqdm and rich. From 8b3723c6d8420bb978f4d68409ff5189f87fe107 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 11 Mar 2024 13:53:06 +0100 Subject: [PATCH 09/30] Update ruff and documentation for hf sb3 (#1866) * Update ruff * Only load weights with `torch.load()` to avoid security issues * Update doc about HF integration and remote code execution * Fix doc build * Revert weight_only=True for policies --- Makefile | 8 ++++---- docs/guide/integrations.rst | 11 ++++++++++- docs/misc/changelog.rst | 16 +++++++++++++--- setup.py | 2 +- stable_baselines3/common/policies.py | 4 +++- stable_baselines3/common/save_util.py | 2 +- stable_baselines3/version.txt | 2 +- 7 files changed, 33 insertions(+), 12 deletions(-) diff --git a/Makefile b/Makefile index 51a5940c6..e0f6b2b0c 100644 --- a/Makefile +++ b/Makefile @@ -18,19 +18,19 @@ type: mypy lint: # stop the build if there are Python syntax errors or undefined names # see https://www.flake8rules.com/ - ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full + ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full # exit-zero treats all errors as warnings. - ruff ${LINT_PATHS} --exit-zero + ruff check ${LINT_PATHS} --exit-zero format: # Sort imports - ruff --select I ${LINT_PATHS} --fix + ruff check --select I ${LINT_PATHS} --fix # Reformat using black black ${LINT_PATHS} check-codestyle: # Sort imports - ruff --select I ${LINT_PATHS} + ruff check --select I ${LINT_PATHS} # Reformat using black black --check ${LINT_PATHS} diff --git a/docs/guide/integrations.rst b/docs/guide/integrations.rst index 14573cdec..9f864a2e0 100644 --- a/docs/guide/integrations.rst +++ b/docs/guide/integrations.rst @@ -70,8 +70,10 @@ Installation .. code-block:: bash + # Download model and save it into the logs/ folder - python -m rl_zoo3.load_from_hub --algo a2c --env LunarLander-v2 -orga sb3 -f logs/ + # Only use TRUST_REMOTE_CODE=True with HF models that can be trusted (here the SB3 organization) + TRUST_REMOTE_CODE=True python -m rl_zoo3.load_from_hub --algo a2c --env LunarLander-v2 -orga sb3 -f logs/ # Test the agent python -m rl_zoo3.enjoy --algo a2c --env LunarLander-v2 -f logs/ # Push model, config and hyperparameters to the hub @@ -86,12 +88,19 @@ For instance ``sb3/demo-hf-CartPole-v1``: .. 
code-block:: python + import os + import gymnasium as gym from huggingface_sb3 import load_from_hub from stable_baselines3 import PPO from stable_baselines3.common.evaluation import evaluate_policy + + # Allow the use of `pickle.load()` when downloading model from the hub + # Please make sure that the organization from which you download can be trusted + os.environ["TRUST_REMOTE_CODE"] = "True" + # Retrieve the model from the hub ## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name}) ## filename = name of the model zip file from the repository diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index e67e5edee..9ed1cfc09 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.3.0a3 (WIP) +Release 2.3.0a4 (WIP) -------------------------- Breaking Changes: @@ -33,6 +33,11 @@ Breaking Changes: # SB3 >= 2.3.0: model = DQN("MlpPolicy", env, learning_start=100) +- For safety, ``torch.load()`` is now called with ``weights_only=True`` when loading torch tensors, + policy ``load()`` still uses ``weights_only=False`` as gymnasium imports are required for it to work +- When using ``huggingface_sb3``, you will now need to set ``TRUST_REMOTE_CODE=True`` when downloading models from the hub, + as ``pickle.load`` is not safe. + New Features: ^^^^^^^^^^^^^ @@ -48,6 +53,11 @@ Bug Fixes: `SBX`_ (SB3 + Jax) ^^^^^^^^^^^^^^^^^^ +- Added support for ``MultiDiscrete`` and ``MultiBinary`` action spaces to PPO +- Added support for large values for gradient_steps to SAC, TD3, and TQC +- Fix ``train()`` signature and update type hints +- Fix replay buffer device at load time +- Added flatten layer Deprecations: ^^^^^^^^^^^^^ @@ -55,7 +65,7 @@ Deprecations: Others: ^^^^^^^ - Updated black from v23 to v24 -- Updated ruff to >= v0.2.2 +- Updated ruff to >= v0.3.1 - Updated env checker for (multi)discrete spaces with non-zero start. Documentation: @@ -66,7 +76,7 @@ Documentation: - Added video link to "Practical Tips for Reliable Reinforcement Learning" video - Added ``render_mode="human"`` in the README example (@marekm4) - Fixed docstring signature for sum_independent_dims (@stagoverflow) -- Updated docstring description for ``log_interval`` in the base class (@rushitnshah). +- Updated docstring description for ``log_interval`` in the base class (@rushitnshah). 
Release 2.2.1 (2023-11-17) -------------------------- diff --git a/setup.py b/setup.py index a077738ff..161539af1 100644 --- a/setup.py +++ b/setup.py @@ -120,7 +120,7 @@ # Type check "mypy", # Lint code and sort imports (flake8 and isort replacement) - "ruff>=0.2.2", + "ruff>=0.3.1", # Reformat "black>=24.2.0,<25", ], diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 50be01c9e..e4d62ef0f 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -173,7 +173,9 @@ def load(cls: Type[SelfBaseModel], path: str, device: Union[th.device, str] = "a :return: """ device = get_device(device) - saved_variables = th.load(path, map_location=device) + # Note(antonin): we cannot use `weights_only=True` here because we need to allow + # gymnasium imports for the policy to be loaded successfully + saved_variables = th.load(path, map_location=device, weights_only=False) # Create policy object model = cls(**saved_variables["data"]) diff --git a/stable_baselines3/common/save_util.py b/stable_baselines3/common/save_util.py index 9fca6a832..2d8652006 100644 --- a/stable_baselines3/common/save_util.py +++ b/stable_baselines3/common/save_util.py @@ -447,7 +447,7 @@ def load_from_zip_file( file_content.seek(0) # Load the parameters with the right ``map_location``. # Remove ".pth" ending with splitext - th_object = th.load(file_content, map_location=device) + th_object = th.load(file_content, map_location=device, weights_only=True) # "tensors.pth" was renamed "pytorch_variables.pth" in v0.9.0, see PR #138 if file_path == "pytorch_variables.pth" or file_path == "tensors.pth": # PyTorch variables (not state_dicts) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 5334cfa48..87ced0fe9 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.3.0a3 +2.3.0a4 From 071226d3e83d2728153c37dcf964a0e5d4d967b9 Mon Sep 17 00:00:00 2001 From: Corentin <111868204+corentinlger@users.noreply.github.com> Date: Fri, 22 Mar 2024 12:13:48 +0100 Subject: [PATCH 10/30] Log success rate for on policy algorithms (#1870) * Add success rate in monitor for on policy algorithms * Update changelog * make commit-checks refactoring * Assert buffers are not none in _dump_logs * Automatic refactoring of the type hinting * Add success_rate logging test for on policy algorithms * Update changelog * Reformat * Fix tests and update changelog --------- Co-authored-by: Antonin Raffin --- docs/misc/changelog.rst | 4 +- .../common/on_policy_algorithm.py | 36 +++++--- stable_baselines3/version.txt | 2 +- tests/test_logger.py | 92 ++++++++++++++++++- 4 files changed, 120 insertions(+), 14 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 9ed1cfc09..842db827c 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.3.0a4 (WIP) +Release 2.3.0a5 (WIP) -------------------------- Breaking Changes: @@ -41,9 +41,11 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ +- Log success rate ``rollout/success_rate`` when available for on policy algorithms (@corentinlger) Bug Fixes: ^^^^^^^^^^ +- Fixed ``monitor_wrapper`` argument that was not passed to the parent class, and dones argument that wasn't passed to ``_update_into_buffer`` (@corentinlger) `SB3-Contrib`_ ^^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index ddd0f8de2..1ba36d5f0 100644 --- 
a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -92,6 +92,7 @@ def __init__( use_sde=use_sde, sde_sample_freq=sde_sample_freq, support_multi_env=True, + monitor_wrapper=monitor_wrapper, seed=seed, stats_window_size=stats_window_size, tensorboard_log=tensorboard_log, @@ -200,7 +201,7 @@ def collect_rollouts( if not callback.on_step(): return False - self._update_info_buffer(infos) + self._update_info_buffer(infos, dones) n_steps += 1 if isinstance(self.action_space, spaces.Discrete): @@ -250,6 +251,28 @@ def train(self) -> None: """ raise NotImplementedError + def _dump_logs(self, iteration: int) -> None: + """ + Write log. + + :param iteration: Current logging iteration + """ + assert self.ep_info_buffer is not None + assert self.ep_success_buffer is not None + + time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon) + fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed) + self.logger.record("time/iterations", iteration, exclude="tensorboard") + if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: + self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer])) + self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer])) + self.logger.record("time/fps", fps) + self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard") + self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard") + if len(self.ep_success_buffer) > 0: + self.logger.record("rollout/success_rate", safe_mean(self.ep_success_buffer)) + self.logger.dump(step=self.num_timesteps) + def learn( self: SelfOnPolicyAlgorithm, total_timesteps: int, @@ -285,16 +308,7 @@ def learn( # Display training infos if log_interval is not None and iteration % log_interval == 0: assert self.ep_info_buffer is not None - time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon) - fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed) - self.logger.record("time/iterations", iteration, exclude="tensorboard") - if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: - self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer])) - self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer])) - self.logger.record("time/fps", fps) - self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard") - self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard") - self.logger.dump(step=self.num_timesteps) + self._dump_logs(iteration) self.train() diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 87ced0fe9..a3b489b55 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.3.0a4 +2.3.0a5 diff --git a/tests/test_logger.py b/tests/test_logger.py index 05bf196a3..dfd9e5567 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -14,7 +14,7 @@ from matplotlib import pyplot as plt from pandas.errors import EmptyDataError -from stable_baselines3 import A2C, DQN +from stable_baselines3 import A2C, DQN, PPO from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.logger import ( DEBUG, @@ -33,6 +33,7 @@ read_csv, read_json, ) +from stable_baselines3.common.monitor import Monitor KEY_VALUES = { "test": 1, @@ -474,3 
+475,92 @@ def get_printed(self) -> str: """ assert printed == desired_printed + + +class DummySuccessEnv(gym.Env): + """ + Create a dummy success environment that returns wether True or False for info['is_success'] + at the end of an episode according to its dummy successes list + """ + + def __init__(self, dummy_successes, ep_steps): + """Init the dummy success env + + :param dummy_successes: list of size (n_logs_iterations, n_episodes_per_log) that specifies + the success value of log iteration i at episode j + :param ep_steps: number of steps per episode (to activate truncated) + """ + self.n_steps = 0 + self.log_id = 0 + self.ep_id = 0 + + self.ep_steps = ep_steps + + self.dummy_success = dummy_successes + self.num_logs = len(dummy_successes) + self.ep_per_log = len(dummy_successes[0]) + self.steps_per_log = self.ep_per_log * self.ep_steps + + self.action_space = spaces.Discrete(2) + self.observation_space = spaces.Discrete(2) + + def reset(self, seed=None, options=None): + """ + Reset the env and advance to the next episode_id to get the next dummy success + """ + self.n_steps = 0 + + if self.ep_id == self.ep_per_log: + self.ep_id = 0 + self.log_id = (self.log_id + 1) % self.num_logs + + return self.observation_space.sample(), {} + + def step(self, action): + """ + Step and return a dummy success when an episode is truncated + """ + self.n_steps += 1 + truncated = self.n_steps >= self.ep_steps + + info = {} + if truncated: + maybe_success = self.dummy_success[self.log_id][self.ep_id] + info["is_success"] = maybe_success + self.ep_id += 1 + return self.observation_space.sample(), 0.0, False, truncated, info + + +def test_rollout_success_rate_on_policy_algorithm(tmp_path): + """ + Test if the rollout/success_rate information is correctly logged with on policy algorithms + + To do so, create a dummy environment that takes as argument dummy successes (i.e when an episode) + is going to be successfull or not. 
+ """ + + STATS_WINDOW_SIZE = 10 + # Add dummy successes with 0.3, 0.5 and 0.8 success_rate of length STATS_WINDOW_SIZE + dummy_successes = [ + [True] * 3 + [False] * 7, + [True] * 5 + [False] * 5, + [True] * 8 + [False] * 2, + ] + ep_steps = 64 + + # Monitor the env to track the success info + monitor_file = str(tmp_path / "monitor.csv") + env = Monitor(DummySuccessEnv(dummy_successes, ep_steps), filename=monitor_file, info_keywords=("is_success",)) + + # Equip the model of a custom logger to check the success_rate info + model = PPO("MlpPolicy", env=env, stats_window_size=STATS_WINDOW_SIZE, n_steps=env.steps_per_log, verbose=1) + logger = InMemoryLogger() + model.set_logger(logger) + + # Make the model learn and check that the success rate corresponds to the ratio of dummy successes + model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1) + assert logger.name_to_value["rollout/success_rate"] == 0.3 + model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1) + assert logger.name_to_value["rollout/success_rate"] == 0.5 + model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1) + assert logger.name_to_value["rollout/success_rate"] == 0.8 From 429be93c48502c634148a6fa1e2a0e421bcd5a20 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sun, 31 Mar 2024 20:25:19 +0200 Subject: [PATCH 11/30] Release v2.3.0 (#1879) * Release v2.3.0 * Fix typos --- docs/misc/changelog.rst | 24 +++++++++++++++++++----- stable_baselines3/version.txt | 2 +- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 842db827c..376585a66 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,9 +3,12 @@ Changelog ========== -Release 2.3.0a5 (WIP) +Release 2.3.0 (2024-03-31) -------------------------- +**New defaults hyperparameters for DDPG, TD3 and DQN** + + Breaking Changes: ^^^^^^^^^^^^^^^^^ - The defaults hyperparameters of ``TD3`` and ``DDPG`` have been changed to be more consistent with ``SAC`` @@ -19,11 +22,11 @@ Breaking Changes: .. note:: - Two inconsistencies remains: the default network architecture for ``TD3/DDPG`` is ``[400, 300]`` instead of ``[256, 256]`` for SAC (for backward compatibility reasons, see `report on the influence of the network size `_) and the default learning rate is 1e-3 instead of 3e-4 for SAC (for performance reasons, see `W&B report on the influence of the lr `_) + Two inconsistencies remain: the default network architecture for ``TD3/DDPG`` is ``[400, 300]`` instead of ``[256, 256]`` for SAC (for backward compatibility reasons, see `report on the influence of the network size `_) and the default learning rate is 1e-3 instead of 3e-4 for SAC (for performance reasons, see `W&B report on the influence of the lr `_) -- The default ``leanrning_starts`` parameter of ``DQN`` have been changed to be consistent with the other offpolicy algorithms +- The default ``learning_starts`` parameter of ``DQN`` have been changed to be consistent with the other offpolicy algorithms .. code-block:: python @@ -35,8 +38,7 @@ Breaking Changes: - For safety, ``torch.load()`` is now called with ``weights_only=True`` when loading torch tensors, policy ``load()`` still uses ``weights_only=False`` as gymnasium imports are required for it to work -- When using ``huggingface_sb3``, you will now need to set ``TRUST_REMOTE_CODE=True`` when downloading models from the hub, - as ``pickle.load`` is not safe. 
+- When using ``huggingface_sb3``, you will now need to set ``TRUST_REMOTE_CODE=True`` when downloading models from the hub, as ``pickle.load`` is not safe. New Features: @@ -49,9 +51,20 @@ Bug Fixes: `SB3-Contrib`_ ^^^^^^^^^^^^^^ +- Added ``rollout_buffer_class`` and ``rollout_buffer_kwargs`` arguments to MaskablePPO +- Fixed ``train_freq`` type annotation for tqc and qrdqn (@Armandpl) +- Fixed ``sb3_contrib/common/maskable/*.py`` type annotations +- Fixed ``sb3_contrib/ppo_mask/ppo_mask.py`` type annotations +- Fixed ``sb3_contrib/common/vec_env/async_eval.py`` type annotations +- Add some additional notes about ``MaskablePPO`` (evaluation and multi-process) (@icheered) + `RL Zoo`_ ^^^^^^^^^ +- Updated defaults hyperparameters for TD3/DDPG to be more consistent with SAC +- Upgraded MuJoCo envs hyperparameters to v4 (pre-trained agents need to be updated) +- Added test dependencies to `setup.py` (@power-edge) +- Simplify dependencies of `requirements.txt` (remove duplicates from `setup.py`) `SBX`_ (SB3 + Jax) ^^^^^^^^^^^^^^^^^^ @@ -60,6 +73,7 @@ Bug Fixes: - Fix ``train()`` signature and update type hints - Fix replay buffer device at load time - Added flatten layer +- Added ``CrossQ`` Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index a3b489b55..276cbf9e2 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.3.0a5 +2.3.0 From 40ba50467ca962180aaa8a3fefbe03bcfd352909 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 1 Apr 2024 16:07:52 +0200 Subject: [PATCH 12/30] Fix typo in changelog (#1882) --- docs/misc/changelog.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 376585a66..c1560201c 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -32,9 +32,9 @@ Breaking Changes: .. 
code-block:: python # SB3 < 2.3.0 default hyperparameters, 50_000 corresponded to Atari defaults hyperparameters - # model = DQN("MlpPolicy", env, learning_start=50_000) + # model = DQN("MlpPolicy", env, learning_starts=50_000) # SB3 >= 2.3.0: - model = DQN("MlpPolicy", env, learning_start=100) + model = DQN("MlpPolicy", env, learning_starts=100) - For safety, ``torch.load()`` is now called with ``weights_only=True`` when loading torch tensors, policy ``load()`` still uses ``weights_only=False`` as gymnasium imports are required for it to work From 5623d98f9d6bcfd2ab450e850c3f7b090aef5642 Mon Sep 17 00:00:00 2001 From: Chaitanya Bisht Date: Mon, 8 Apr 2024 19:18:26 +0530 Subject: [PATCH 13/30] Fixed broken link in ppo.rst (#1884) --- docs/modules/ppo.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/modules/ppo.rst b/docs/modules/ppo.rst index ace2fcc63..b5e667241 100644 --- a/docs/modules/ppo.rst +++ b/docs/modules/ppo.rst @@ -23,7 +23,7 @@ Notes - Original paper: https://arxiv.org/abs/1707.06347 - Clear explanation of PPO on Arxiv Insights channel: https://www.youtube.com/watch?v=5P7I-xPq8u8 -- OpenAI blog post: https://blog.openai.com/openai-baselines-ppo/ +- OpenAI blog post: https://openai.com/research/openai-baselines-ppo - Spinning Up guide: https://spinningup.openai.com/en/latest/algorithms/ppo.html - 37 implementation details blog: https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/ From 9a749389d30ac2146046e77f361de24a35395bb6 Mon Sep 17 00:00:00 2001 From: Mark Smith Date: Mon, 22 Apr 2024 04:04:01 -0400 Subject: [PATCH 14/30] Cast learning_rate to float lambda for pickle safety when doing model.load (#1901) * create failing test for unpickle error * Fix learning_rate argument causing failure in weights_only=True if passed a function with non-float types * Updated with feedback from araffin on PR#1901 * Update test and version * Update changelog and SBX doc --------- Co-authored-by: Antonin Raffin --- README.md | 2 +- docs/guide/algos.rst | 3 ++- docs/guide/sbx.rst | 15 +++++++++------ docs/misc/changelog.rst | 14 +++++++++++++- stable_baselines3/common/utils.py | 4 +++- stable_baselines3/version.txt | 2 +- tests/test_save_load.py | 14 ++++++++++++++ 7 files changed, 43 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 6e55f1030..018b8c6f3 100644 --- a/README.md +++ b/README.md @@ -85,7 +85,7 @@ Documentation is available online: [https://sb3-contrib.readthedocs.io/](https:/ ## Stable-Baselines Jax (SBX) -[Stable Baselines Jax (SBX)](https://github.com/araffin/sbx) is a proof of concept version of Stable-Baselines3 in Jax. +[Stable Baselines Jax (SBX)](https://github.com/araffin/sbx) is a proof of concept version of Stable-Baselines3 in Jax, with recent algorithms like DroQ or CrossQ. It provides a minimal number of features compared to SB3 but can be much faster (up to 20x times!): https://twitter.com/araffin2/status/1590714558628253698 diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst index 33ac3ba46..d5e7ae1d2 100644 --- a/docs/guide/algos.rst +++ b/docs/guide/algos.rst @@ -43,7 +43,8 @@ Actions ``gym.spaces``: .. note:: - More algorithms (like QR-DQN or TQC) are implemented in our :ref:`contrib repo `. + More algorithms (like QR-DQN or TQC) are implemented in our :ref:`contrib repo ` + and in our :ref:`SBX (SB3 + Jax) repo ` (DroQ, CrossQ, ...). .. 
note:: diff --git a/docs/guide/sbx.rst b/docs/guide/sbx.rst index 52b4348bc..ed5369ea4 100644 --- a/docs/guide/sbx.rst +++ b/docs/guide/sbx.rst @@ -17,6 +17,7 @@ Implemented algorithms: - Deep Q Network (DQN) - Twin Delayed DDPG (TD3) - Deep Deterministic Policy Gradient (DDPG) +- Batch Normalization in Deep Reinforcement Learning (CrossQ) As SBX follows SB3 API, it is also compatible with the `RL Zoo `_. @@ -29,16 +30,17 @@ For that you will need to create two files: import rl_zoo3 import rl_zoo3.train from rl_zoo3.train import train - - from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ + from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ rl_zoo3.ALGOS["ddpg"] = DDPG rl_zoo3.ALGOS["dqn"] = DQN - rl_zoo3.ALGOS["droq"] = DroQ + # See SBX readme to use DroQ configuration + # rl_zoo3.ALGOS["droq"] = DroQ rl_zoo3.ALGOS["sac"] = SAC rl_zoo3.ALGOS["ppo"] = PPO rl_zoo3.ALGOS["td3"] = TD3 rl_zoo3.ALGOS["tqc"] = TQC + rl_zoo3.ALGOS["crossq"] = CrossQ rl_zoo3.train.ALGOS = rl_zoo3.ALGOS rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS @@ -56,16 +58,17 @@ Then you can call ``python train_sbx.py --algo sac --env Pendulum-v1`` and use t import rl_zoo3 import rl_zoo3.enjoy from rl_zoo3.enjoy import enjoy - - from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ + from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ rl_zoo3.ALGOS["ddpg"] = DDPG rl_zoo3.ALGOS["dqn"] = DQN - rl_zoo3.ALGOS["droq"] = DroQ + # See SBX readme to use DroQ configuration + # rl_zoo3.ALGOS["droq"] = DroQ rl_zoo3.ALGOS["sac"] = SAC rl_zoo3.ALGOS["ppo"] = PPO rl_zoo3.ALGOS["td3"] = TD3 rl_zoo3.ALGOS["tqc"] = TQC + rl_zoo3.ALGOS["crossq"] = CrossQ rl_zoo3.enjoy.ALGOS = rl_zoo3.ALGOS rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index c1560201c..9080b6245 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,6 +3,18 @@ Changelog ========== +Release 2.3.1 (2024-04-22) +-------------------------- + +Bug Fixes: +^^^^^^^^^^ +- Cast return value of learning rate schedule to float, to avoid issue when loading model because of ``weights_only=True`` (@markscsmith) + +Documentation: +^^^^^^^^^^^^^^ +- Updated SBX documentation (CrossQ and deprecated DroQ) + + Release 2.3.0 (2024-03-31) -------------------------- @@ -1593,4 +1605,4 @@ And all the contributors: @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @ReHoss @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto @lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger -@marekm4 @stagoverflow @rushitnshah +@marekm4 @stagoverflow @rushitnshah @markscsmith diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 3ff193786..bcde1cfa0 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -92,7 +92,9 @@ def get_schedule_fn(value_schedule: Union[Schedule, float]) -> Schedule: value_schedule = constant_fn(float(value_schedule)) else: assert callable(value_schedule) - return value_schedule + # Cast to float to avoid unpickling errors to enable weights_only=True, see GH#1900 + # Some types are have odd behaviors when part of a Schedule, like numpy floats + return lambda progress_remaining: float(value_schedule(progress_remaining)) def get_linear_fn(start: float, end: float, end_fraction: float) -> Schedule: diff --git a/stable_baselines3/version.txt 
b/stable_baselines3/version.txt index 276cbf9e2..2bf1c1ccf 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.3.0 +2.3.1 diff --git a/tests/test_save_load.py b/tests/test_save_load.py index e7123e984..0162e3650 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -783,3 +783,17 @@ def test_no_resource_warning(tmp_path): fp.seek(0) model.load_replay_buffer(fp) assert not fp.closed + + +def test_cast_lr_schedule(tmp_path): + # See GH#1900 + model = PPO("MlpPolicy", "Pendulum-v1", learning_rate=lambda t: t * np.sin(1.0)) + # Note: for recent version of numpy, np.float64 is a subclass of float + # so we need to use type here + # assert isinstance(model.lr_schedule(1.0), float) + assert type(model.lr_schedule(1.0)) is float # noqa: E721 + assert np.allclose(model.lr_schedule(0.5), 0.5 * np.sin(1.0)) + model.save(tmp_path / "ppo.zip") + model = PPO.load(tmp_path / "ppo.zip") + assert type(model.lr_schedule(1.0)) is float # noqa: E721 + assert np.allclose(model.lr_schedule(0.5), 0.5 * np.sin(1.0)) From 4af4a32d1b5acb06d585ef7bb0a00c83810fe5c3 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 22 Apr 2024 10:24:53 +0200 Subject: [PATCH 15/30] Update RL Tips and Tricks section --- docs/guide/rl_tips.rst | 43 +++++++++++++++++++++-------------------- docs/misc/changelog.rst | 1 + 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/docs/guide/rl_tips.rst b/docs/guide/rl_tips.rst index ae37640c7..c4f277f3e 100644 --- a/docs/guide/rl_tips.rst +++ b/docs/guide/rl_tips.rst @@ -4,7 +4,7 @@ Reinforcement Learning Tips and Tricks ====================================== -The aim of this section is to help you do reinforcement learning experiments. +The aim of this section is to help you run reinforcement learning experiments. It covers general advice about RL (where to start, which algorithm to choose, how to evaluate an algorithm, ...), as well as tips and tricks when using a custom environment or implementing an RL algorithm. @@ -14,6 +14,11 @@ as well as tips and tricks when using a custom environment or implementing an RL this section in more details. You can also find the `slides here `_. +.. note:: + + We also have a `video on Designing and Running Real-World RL Experiments `_, slides are `can be found online `_. + + General advice when using Reinforcement Learning ================================================ @@ -103,19 +108,19 @@ and this `issue `_ by Cé Which algorithm should I use? ============================= -There is no silver bullet in RL, depending on your needs and problem, you may choose one or the other. +There is no silver bullet in RL, you can choose one or the other depending on your needs and problems. The first distinction comes from your action space, i.e., do you have discrete (e.g. LEFT, RIGHT, ...) or continuous actions (ex: go to a certain speed)? -Some algorithms are only tailored for one or the other domain: ``DQN`` only supports discrete actions, where ``SAC`` is restricted to continuous actions. +Some algorithms are only tailored for one or the other domain: ``DQN`` supports only discrete actions, while ``SAC`` is restricted to continuous actions. -The second difference that will help you choose is whether you can parallelize your training or not. +The second difference that will help you decide is whether you can parallelize your training or not. If what matters is the wall clock training time, then you should lean towards ``A2C`` and its derivatives (PPO, ...). 
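As a rough illustration of how the action-space advice above translates into code (a minimal sketch, not taken from the documentation being patched; the environment and algorithms are arbitrary examples):

.. code-block:: python

    import gymnasium as gym
    from gymnasium import spaces

    from stable_baselines3 import DQN, SAC

    env = gym.make("Pendulum-v1")  # Box action space -> continuous control

    # DQN only supports discrete actions, SAC only continuous ones
    if isinstance(env.action_space, spaces.Discrete):
        model = DQN("MlpPolicy", env, verbose=1)
    else:
        model = SAC("MlpPolicy", env, verbose=1)

    # If wall-clock time matters more than sample efficiency, prefer PPO/A2C
    # together with several vectorized environments (see below).
    model.learn(total_timesteps=1_000)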
Take a look at the `Vectorized Environments `_ to learn more about training with multiple workers. -To accelerate training, you can also take a look at `SBX`_, which is SB3 + Jax, it has fewer features than SB3 but can be up to 20x faster than SB3 PyTorch thanks to JIT compilation of the gradient update. +To accelerate training, you can also take a look at `SBX`_, which is SB3 + Jax, it has less features than SB3 but can be up to 20x faster than SB3 PyTorch thanks to JIT compilation of the gradient update. -In sparse reward settings, we either recommend to use dedicated methods like HER (see below) or population-based algorithms like ARS (available in our :ref:`contrib repo `). +In sparse reward settings, we either recommend using either dedicated methods like HER (see below) or population-based algorithms like ARS (available in our :ref:`contrib repo `). To sum it up: @@ -146,7 +151,7 @@ Continuous Actions Continuous Actions - Single Process ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Current State Of The Art (SOTA) algorithms are ``SAC``, ``TD3`` and ``TQC`` (available in our :ref:`contrib repo `). +Current State Of The Art (SOTA) algorithms are ``SAC``, ``TD3``, ``CrossQ`` and ``TQC`` (available in our :ref:`contrib repo ` and :ref:`SBX (SB3 + Jax) repo `). Please use the hyperparameters in the `RL zoo `_ for best results. If you want an extremely sample-efficient algorithm, we recommend using the `DroQ configuration `_ in `SBX`_ (it does many gradient steps per step in the environment). @@ -155,8 +160,7 @@ If you want an extremely sample-efficient algorithm, we recommend using the `Dro Continuous Actions - Multiprocessed ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Take a look at ``PPO``, ``TRPO`` (available in our :ref:`contrib repo `) or ``A2C``. Again, don't forget to take the hyperparameters from the `RL zoo `_ -for continuous actions problems (cf *Bullet* envs). +Take a look at ``PPO``, ``TRPO`` (available in our :ref:`contrib repo `) or ``A2C``. Again, don't forget to take the hyperparameters from the `RL zoo `_ for continuous actions problems (cf *Bullet* envs). .. note:: @@ -181,26 +185,23 @@ Tips and Tricks when creating a custom environment ================================================== If you want to learn about how to create a custom environment, we recommend you read this `page `_. -We also provide a `colab notebook `_ for -a concrete example of creating a custom gym environment. +We also provide a `colab notebook `_ for a concrete example of creating a custom gym environment. Some basic advice: -- always normalize your observation space when you can, i.e., when you know the boundaries -- normalize your action space and make it symmetric when continuous (cf potential issue below) A good practice is to rescale your actions to lie in [-1, 1]. This does not limit you as you can easily rescale the action inside the environment -- start with shaped reward (i.e. informative reward) and simplified version of your problem -- debug with random actions to check that your environment works and follows the gym interface: +- always normalize your observation space if you can, i.e. if you know the boundaries +- normalize your action space and make it symmetric if it is continuous (see potential problem below) A good practice is to rescale your actions so that they lie in [-1, 1]. This does not limit you, as you can easily rescale the action within the environment. +- start with a shaped reward (i.e. 
informative reward) and a simplified version of your problem +- debug with random actions to check if your environment works and follows the gym interface (with ``check_env``, see below) -Two important things to keep in mind when creating a custom environment is to avoid breaking Markov assumption +Two important things to keep in mind when creating a custom environment are avoiding breaking the Markov assumption and properly handle termination due to a timeout (maximum number of steps in an episode). -For instance, if there is some time delay between action and observation (e.g. due to wifi communication), you should give a history of observations -as input. +For example, if there is a time delay between action and observation (e.g. due to wifi communication), you should provide a history of observations as input. Termination due to timeout (max number of steps per episode) needs to be handled separately. You should return ``truncated = True``. If you are using the gym ``TimeLimit`` wrapper, this will be done automatically. -You can read `Time Limit in RL `_ or take a look at the `RL Tips and Tricks video `_ -for more details. +You can read `Time Limit in RL `_, take a look at the `Designing and Running Real-World RL Experiments video `_ or `RL Tips and Tricks video `_ for more details. We provide a helper to check that your environment runs without error: @@ -234,7 +235,7 @@ If you want to quickly try a random agent on your environment, you can also do: Most reinforcement learning algorithms rely on a Gaussian distribution (initially centered at 0 with std 1) for continuous actions. So, if you forget to normalize the action space when using a custom environment, -this can harm learning and be difficult to debug (cf attached image and `issue #473 `_). +this can harm learning and can be difficult to debug (cf attached image and `issue #473 `_). .. figure:: ../_static/img/mistake.png diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 9080b6245..db065ee4a 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -13,6 +13,7 @@ Bug Fixes: Documentation: ^^^^^^^^^^^^^^ - Updated SBX documentation (CrossQ and deprecated DroQ) +- Updated RL Tips and Tricks section Release 2.3.0 (2024-03-31) From e93175084f1dc4770c6d8070bdf4c99010e4f2c0 Mon Sep 17 00:00:00 2001 From: Corentin <111868204+corentinlger@users.noreply.github.com> Date: Thu, 25 Apr 2024 14:31:15 +0200 Subject: [PATCH 16/30] Adding ER-MRL to community project (#1904) * Add ER_MRL * Update changelog * Move ER-MRL at the end of the file * Improve project description * Update changelog --------- Co-authored-by: Antonin Raffin --- docs/guide/rl_tips.rst | 2 +- docs/misc/changelog.rst | 35 +++++++++++++++++++++++++++++++++++ docs/misc/projects.rst | 11 +++++++++++ stable_baselines3/version.txt | 2 +- 4 files changed, 48 insertions(+), 2 deletions(-) diff --git a/docs/guide/rl_tips.rst b/docs/guide/rl_tips.rst index c4f277f3e..3acd1b433 100644 --- a/docs/guide/rl_tips.rst +++ b/docs/guide/rl_tips.rst @@ -16,7 +16,7 @@ as well as tips and tricks when using a custom environment or implementing an RL .. note:: - We also have a `video on Designing and Running Real-World RL Experiments `_, slides are `can be found online `_. + We also have a `video on Designing and Running Real-World RL Experiments `_, slides `can be found online `_. 
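A minimal sketch of the ``check_env`` helper mentioned earlier in this section (the environment here is an arbitrary example; the rendered documentation contains its own snippet, it is simply not visible in the diff context):

.. code-block:: python

    import gymnasium as gym

    from stable_baselines3.common.env_checker import check_env

    env = gym.make("CartPole-v1")
    # check_env warns about common issues: non-normalized or unbounded spaces,
    # wrong reset()/step() signatures, observations outside the declared space, ...
    check_env(env, warn=True)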
General advice when using Reinforcement Learning diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index db065ee4a..f9f4b47cf 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,6 +3,41 @@ Changelog ========== +Release 2.4.0a0 (WIP) +-------------------------- + +Breaking Changes: +^^^^^^^^^^^^^^^^^ + +New Features: +^^^^^^^^^^^^^ + +Bug Fixes: +^^^^^^^^^^ + +`SB3-Contrib`_ +^^^^^^^^^^^^^^ + +`RL Zoo`_ +^^^^^^^^^ + +`SBX`_ (SB3 + Jax) +^^^^^^^^^^^^^^^^^^ + +Deprecations: +^^^^^^^^^^^^^ + +Others: +^^^^^^^ + +Bug Fixes: +^^^^^^^^^^ + +Documentation: +^^^^^^^^^^^^^^ +- Added ER-MRL to the project page + + Release 2.3.1 (2024-04-22) -------------------------- diff --git a/docs/misc/projects.rst b/docs/misc/projects.rst index 2b2e2405c..5f0c69710 100644 --- a/docs/misc/projects.rst +++ b/docs/misc/projects.rst @@ -239,3 +239,14 @@ Playing Pokemon Red with Reinforcement Learning. | Author: Peter Whidden | Github: https://github.com/PWhiddy/PokemonRedExperiments | Video: https://www.youtube.com/watch?v=DcYLT37ImBY + + +Evolving Reservoirs for Meta Reinforcement Learning +--------------------------------------------------- + +Meta-RL framework to optimize reservoir-like neural structures (special kind of RNNs), and integrate them to RL agents to improve their training. +It enables solving environments involving partial observability or locomotion (e.g MuJoCo), and optimizing reservoirs that can generalize to unseen tasks. + +| Authors: Corentin Léger, Gautier Hamon, Eleni Nisioti, Xavier Hinaut, Clément Moulin-Frier +| Github: https://github.com/corentinlger/ER-MRL +| Paper: https://arxiv.org/abs/2312.06695 diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 2bf1c1ccf..e96f44fb3 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.3.1 +2.4.0a0 From 35eccaf04fa011128f02eaecac6caab535686459 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Lucchesi?= Date: Fri, 26 Apr 2024 12:12:04 +0200 Subject: [PATCH 17/30] Fix tensorboad video slow numpy->torch conversion (#1910) * fixed tb video docs * updated changelog * add comment on expected render() output * Update changelog.rst --------- Co-authored-by: Antonin RAFFIN --- docs/guide/tensorboard.rst | 6 +++++- docs/misc/changelog.rst | 7 +++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/docs/guide/tensorboard.rst b/docs/guide/tensorboard.rst index 4ef1b496a..ba62e5e6b 100644 --- a/docs/guide/tensorboard.rst +++ b/docs/guide/tensorboard.rst @@ -192,6 +192,7 @@ Here is an example of how to render an episode and log the resulting video to Te import gymnasium as gym import torch as th + import numpy as np from stable_baselines3 import A2C from stable_baselines3.common.callbacks import BaseCallback @@ -226,6 +227,9 @@ Here is an example of how to render an episode and log the resulting video to Te :param _locals: A dictionary containing all local variables of the callback's scope :param _globals: A dictionary containing all global variables of the callback's scope """ + # We expect `render()` to return a uint8 array with values in [0, 255] or a float array + # with values in [0, 1], as described in + # https://pytorch.org/docs/stable/tensorboard.html#torch.utils.tensorboard.writer.SummaryWriter.add_video screen = self._eval_env.render(mode="rgb_array") # PyTorch uses CxHxW vs HxWxC gym (and tensorflow) image convention screens.append(screen.transpose(2, 0, 1)) @@ -239,7 +243,7 @@ Here is an example of how to render an episode and log 
the resulting video to Te ) self.logger.record( "trajectory/video", - Video(th.ByteTensor([screens]), fps=40), + Video(th.from_numpy(np.asarray([screens])), fps=40), exclude=("stdout", "log", "json", "csv"), ) return True diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index f9f4b47cf..4fa9043c8 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -35,8 +35,8 @@ Bug Fixes: Documentation: ^^^^^^^^^^^^^^ -- Added ER-MRL to the project page - +- Added ER-MRL to the project page (@corentinlger) +- Updated Tensorboard Logging Videos documentation (@NickLucche) Release 2.3.1 (2024-04-22) -------------------------- @@ -50,7 +50,6 @@ Documentation: - Updated SBX documentation (CrossQ and deprecated DroQ) - Updated RL Tips and Tricks section - Release 2.3.0 (2024-03-31) -------------------------- @@ -1641,4 +1640,4 @@ And all the contributors: @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @ReHoss @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto @lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger -@marekm4 @stagoverflow @rushitnshah @markscsmith +@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche From 285e01f64aa8ba4bd15aa339c45876d56ed0c3b4 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Sat, 27 Apr 2024 15:08:38 +0200 Subject: [PATCH 18/30] Hotfix: revert loading with `weights_only=True` (#1913) --- docs/misc/changelog.rst | 17 +++++++++++++++++ stable_baselines3/common/save_util.py | 3 ++- stable_baselines3/version.txt | 2 +- 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 4fa9043c8..203f15e3a 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -33,11 +33,23 @@ Others: Bug Fixes: ^^^^^^^^^^ +Documentation: +^^^^^^^^^^^^^^ + +Release 2.3.2 (2024-04-27) +-------------------------- + +Bug Fixes: +^^^^^^^^^^ +- Reverted ``torch.load()`` to be called ``weights_only=False`` as it caused loading issue with old version of PyTorch. + + Documentation: ^^^^^^^^^^^^^^ - Added ER-MRL to the project page (@corentinlger) - Updated Tensorboard Logging Videos documentation (@NickLucche) + Release 2.3.1 (2024-04-22) -------------------------- @@ -55,6 +67,11 @@ Release 2.3.0 (2024-03-31) **New defaults hyperparameters for DDPG, TD3 and DQN** +.. warning:: + + Because of ``weights_only=True``, this release breaks loading of policies when using PyTorch 1.13. + Please upgrade to PyTorch >= 2.0 or upgrade SB3 version (we reverted the change in SB3 2.3.2) + Breaking Changes: ^^^^^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/save_util.py b/stable_baselines3/common/save_util.py index 2d8652006..a85c9c2ec 100644 --- a/stable_baselines3/common/save_util.py +++ b/stable_baselines3/common/save_util.py @@ -447,7 +447,8 @@ def load_from_zip_file( file_content.seek(0) # Load the parameters with the right ``map_location``. 
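# (map_location lets tensors that were saved on GPU be loaded on a CPU-only machine)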
# Remove ".pth" ending with splitext - th_object = th.load(file_content, map_location=device, weights_only=True) + # Note(antonin): we cannot use weights_only=True, as it breaks with PyTorch 1.13, see GH#1911 + th_object = th.load(file_content, map_location=device, weights_only=False) # "tensors.pth" was renamed "pytorch_variables.pth" in v0.9.0, see PR #138 if file_path == "pytorch_variables.pth" or file_path == "tensors.pth": # PyTorch variables (not state_dicts) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index e96f44fb3..f90b1afc0 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.4.0a0 +2.3.2 From 766b9e9f7ddd51c58d38fd090927c7242a249d2f Mon Sep 17 00:00:00 2001 From: Andrew James Date: Mon, 13 May 2024 10:28:23 -0500 Subject: [PATCH 19/30] Avoid torch type-error under torch.compile (#1922) * Avoid torch type-error under torch.compile * Update changelog and version * Update stable_baselines3/common/buffers.py Co-authored-by: Antonin RAFFIN --------- Co-authored-by: Antonin Raffin --- docs/misc/changelog.rst | 1 + stable_baselines3/common/buffers.py | 2 +- stable_baselines3/version.txt | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 203f15e3a..16712fb79 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -14,6 +14,7 @@ New Features: Bug Fixes: ^^^^^^^^^^ +- Cast type in compute gae method to avoid error when using torch compile (@amjames) `SB3-Contrib`_ ^^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index 306b43571..651ecdb2d 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -424,7 +424,7 @@ def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarra last_gae_lam = 0 for step in reversed(range(self.buffer_size)): if step == self.buffer_size - 1: - next_non_terminal = 1.0 - dones + next_non_terminal = 1.0 - dones.astype(np.float32) next_values = last_values else: next_non_terminal = 1.0 - self.episode_starts[step + 1] diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index f90b1afc0..e96f44fb3 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.3.2 +2.4.0a0 From 4317c62598aea33251ab523932e8b51916c04476 Mon Sep 17 00:00:00 2001 From: Chris Schindlbeck Date: Wed, 15 May 2024 15:19:39 +0200 Subject: [PATCH 20/30] Fix various typos (#1926) * Fix various typos * Update changelog --------- Co-authored-by: Antonin Raffin --- README.md | 2 +- docs/guide/examples.rst | 2 +- docs/misc/changelog.rst | 9 +++++---- stable_baselines3/common/base_class.py | 2 +- stable_baselines3/common/callbacks.py | 2 +- stable_baselines3/common/env_checker.py | 2 +- stable_baselines3/common/monitor.py | 2 +- stable_baselines3/common/policies.py | 2 +- stable_baselines3/common/results_plotter.py | 2 +- stable_baselines3/common/running_mean_std.py | 2 +- stable_baselines3/common/torch_layers.py | 2 +- stable_baselines3/common/type_aliases.py | 2 +- stable_baselines3/common/vec_env/patch_gym.py | 2 +- stable_baselines3/common/vec_env/vec_normalize.py | 2 +- stable_baselines3/her/her_replay_buffer.py | 6 +++--- stable_baselines3/sac/policies.py | 2 +- stable_baselines3/td3/policies.py | 2 +- tests/test_buffers.py | 6 +++--- tests/test_cnn.py | 2 +- tests/test_dict_env.py | 2 +- tests/test_her.py | 2 +- tests/test_logger.py | 4 ++-- tests/test_vec_normalize.py | 2 +- 23 
files changed, 32 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index 018b8c6f3..78592bae8 100644 --- a/README.md +++ b/README.md @@ -192,7 +192,7 @@ All the following examples can be executed online using Google Colab notebooks: 1: Implemented in [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) GitHub repository. Actions `gym.spaces`: - * `Box`: A N-dimensional box that containes every point in the action space. + * `Box`: A N-dimensional box that contains every point in the action space. * `Discrete`: A list of possible actions, where each timestep only one of the actions can be used. * `MultiDiscrete`: A list of possible actions, where each timestep only one action of each discrete set can be used. * `MultiBinary`: A list of possible actions, where each timestep any of the actions can be used in any combination. diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index a4729bfb3..67a477769 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -128,7 +128,7 @@ Multiprocessing: Unleashing the Power of Vectorized Environments :param env_id: the environment ID :param num_env: the number of environments you wish to have in subprocesses - :param seed: the inital seed for RNG + :param seed: the initial seed for RNG :param rank: index of the subprocess """ def _init(): diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 16712fb79..a4c333e08 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -30,6 +30,7 @@ Deprecations: Others: ^^^^^^^ +- Fix various typos (@cschindlbeck) Bug Fixes: ^^^^^^^^^^ @@ -403,7 +404,7 @@ Breaking Changes: ^^^^^^^^^^^^^^^^^ - Removed shared layers in ``mlp_extractor`` (@AlexPasqua) - Refactored ``StackedObservations`` (it now handles dict obs, ``StackedDictObservations`` was removed) -- You must now explicitely pass a ``features_extractor`` parameter when calling ``extract_features()`` +- You must now explicitly pass a ``features_extractor`` parameter when calling ``extract_features()`` - Dropped offline sampling for ``HerReplayBuffer`` - As ``HerReplayBuffer`` was refactored to support multiprocessing, previous replay buffer are incompatible with this new version - ``HerReplayBuffer`` doesn't require a ``max_episode_length`` anymore @@ -535,7 +536,7 @@ Bug Fixes: Deprecations: ^^^^^^^^^^^^^ -- You should now explicitely pass a ``features_extractor`` parameter when calling ``extract_features()`` +- You should now explicitly pass a ``features_extractor`` parameter when calling ``extract_features()`` - Deprecated shared layers in ``MlpExtractor`` (@AlexPasqua) Others: @@ -746,7 +747,7 @@ Bug Fixes: - Fixed a bug in ``HumanOutputFormat``. Distinct keys truncated to the same prefix would overwrite each others value, resulting in only one being output. This now raises an error (this should only affect a small fraction of use cases with very long keys.) 
-- Routing all the ``nn.Module`` calls through implicit rather than explict forward as per pytorch guidelines (@manuel-delverme) +- Routing all the ``nn.Module`` calls through implicit rather than explicit forward as per pytorch guidelines (@manuel-delverme) - Fixed a bug in ``VecNormalize`` where error occurs when ``norm_obs`` is set to False for environment with dictionary observation (@buoyancy99) - Set default ``env`` argument to ``None`` in ``HerReplayBuffer.sample`` (@qgallouedec) - Fix ``batch_size`` typing in ``DQN`` (@qgallouedec) @@ -1658,4 +1659,4 @@ And all the contributors: @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @ReHoss @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto @lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger -@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche +@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index e6c7d3cfc..054e58a1f 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -48,7 +48,7 @@ def maybe_make_env(env: Union[GymEnv, str], verbose: int) -> GymEnv: """If env is a string, make the environment; otherwise, return env. :param env: The environment to learn from. - :param verbose: Verbosity level: 0 for no output, 1 for indicating if envrironment is created + :param verbose: Verbosity level: 0 for no output, 1 for indicating if environment is created :return A Gym (vector) environment. """ if isinstance(env, str): diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py index 2898df8f4..48b6011d1 100644 --- a/stable_baselines3/common/callbacks.py +++ b/stable_baselines3/common/callbacks.py @@ -606,7 +606,7 @@ def __init__(self, max_episodes: int, verbose: int = 0): self.n_episodes = 0 def _init_callback(self) -> None: - # At start set total max according to number of envirnments + # At start set total max according to number of environments self._total_max_episodes = self.max_episodes * self.training_env.num_envs def _on_step(self) -> bool: diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index f24c86ec9..090d609ba 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -397,7 +397,7 @@ def _check_render(env: gym.Env, warn: bool = False) -> None: # pragma: no cover "you may have trouble when calling `.render()`" ) - # Only check currrent render mode + # Only check current render mode if env.render_mode: env.render() env.close() diff --git a/stable_baselines3/common/monitor.py b/stable_baselines3/common/monitor.py index 5253954e8..fb8ce33c6 100644 --- a/stable_baselines3/common/monitor.py +++ b/stable_baselines3/common/monitor.py @@ -189,7 +189,7 @@ def __init__( filename = os.path.realpath(filename) # Create (if any) missing filename directories os.makedirs(os.path.dirname(filename), exist_ok=True) - # Append mode when not overridding existing file + # Append mode when not overriding existing file mode = "w" if override_existing else "a" # Prevent newline issue on Windows, see GH issue #692 self.file_handler = open(filename, f"{mode}t", newline="\n") diff --git a/stable_baselines3/common/policies.py 
b/stable_baselines3/common/policies.py index e4d62ef0f..3c9b14aaa 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -922,7 +922,7 @@ class ContinuousCritic(BaseModel): By default, it creates two critic networks used to reduce overestimation thanks to clipped Q-learning (cf TD3 paper). - :param observation_space: Obervation space + :param observation_space: Observation space :param action_space: Action space :param net_arch: Network architecture :param features_extractor: Network to extract features diff --git a/stable_baselines3/common/results_plotter.py b/stable_baselines3/common/results_plotter.py index 1324557d1..f4c1a7a05 100644 --- a/stable_baselines3/common/results_plotter.py +++ b/stable_baselines3/common/results_plotter.py @@ -46,7 +46,7 @@ def window_func(var_1: np.ndarray, var_2: np.ndarray, window: int, func: Callabl def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray]: """ - Decompose a data frame variable to x ans ys + Decompose a data frame variable to x and ys :param data_frame: the input data :param x_axis: the axis for the x and y output diff --git a/stable_baselines3/common/running_mean_std.py b/stable_baselines3/common/running_mean_std.py index 9dfa4b84c..ac3538c50 100644 --- a/stable_baselines3/common/running_mean_std.py +++ b/stable_baselines3/common/running_mean_std.py @@ -6,7 +6,7 @@ class RunningMeanStd: def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()): """ - Calulates the running mean and std of a data stream + Calculates the running mean and std of a data stream https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm :param epsilon: helps with arithmetic issues diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index ad6c7eef1..bb3ba5de8 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -189,7 +189,7 @@ def __init__( # save dimensions of layers in policy and value nets if isinstance(net_arch, dict): - # Note: if key is not specificed, assume linear network + # Note: if key is not specified, assume linear network pi_layers_dims = net_arch.get("pi", []) # Layer sizes of the policy network vf_layers_dims = net_arch.get("vf", []) # Layer sizes of the value network else: diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index 85d09066e..042c66f9c 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -24,7 +24,7 @@ PyTorchObs = Union[th.Tensor, TensorDict] # A schedule takes the remaining progress as input -# and ouputs a scalar (e.g. learning rate, clip range, ...) +# and outputs a scalar (e.g. learning rate, clip range, ...) 
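# For example, `lambda progress_remaining: progress_remaining * 3e-4` decays
# a learning rate linearly from 3e-4 (start of training, progress_remaining=1) to 0 (end)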
Schedule = Callable[[float], float] diff --git a/stable_baselines3/common/vec_env/patch_gym.py b/stable_baselines3/common/vec_env/patch_gym.py index 2da76a9b2..6ba655ebf 100644 --- a/stable_baselines3/common/vec_env/patch_gym.py +++ b/stable_baselines3/common/vec_env/patch_gym.py @@ -71,7 +71,7 @@ def _convert_space(space: Union["gym.Space", gymnasium.Space]) -> gymnasium.Spac :return: Patched space (gymnasium Space) """ - # Gymnasium space, no convertion to be done + # Gymnasium space, no conversion to be done if isinstance(space, gymnasium.Space): return space diff --git a/stable_baselines3/common/vec_env/vec_normalize.py b/stable_baselines3/common/vec_env/vec_normalize.py index 391ce342d..ab1d8403a 100644 --- a/stable_baselines3/common/vec_env/vec_normalize.py +++ b/stable_baselines3/common/vec_env/vec_normalize.py @@ -111,7 +111,7 @@ def _sanity_checks(self) -> None: raise ValueError( f"VecNormalize only supports `gym.spaces.Box` observation spaces but {obs_key} " f"is of type {self.observation_space.spaces[obs_key]}. " - "You should probably explicitely pass the observation keys " + "You should probably explicitly pass the observation keys " " that should be normalized via the `norm_obs_keys` parameter." ) diff --git a/stable_baselines3/her/her_replay_buffer.py b/stable_baselines3/her/her_replay_buffer.py index 5f0765884..579c6ebf1 100644 --- a/stable_baselines3/her/her_replay_buffer.py +++ b/stable_baselines3/her/her_replay_buffer.py @@ -255,7 +255,7 @@ def _get_real_samples( Get the samples corresponding to the batch and environment indices. :param batch_indices: Indices of the transitions - :param env_indices: Indices of the envrionments + :param env_indices: Indices of the environments :param env: associated gym VecEnv to normalize the observations/rewards when sampling, defaults to None :return: Samples @@ -294,7 +294,7 @@ def _get_virtual_samples( Get the samples, sample new desired goals and compute new rewards. :param batch_indices: Indices of the transitions - :param env_indices: Indices of the envrionments + :param env_indices: Indices of the environments :param env: associated gym VecEnv to normalize the observations/rewards when sampling, defaults to None :return: Samples, with new desired goals and new rewards @@ -357,7 +357,7 @@ def _sample_goals(self, batch_indices: np.ndarray, env_indices: np.ndarray) -> n Sample goals based on goal_selection_strategy. :param batch_indices: Indices of the transitions - :param env_indices: Indices of the envrionments + :param env_indices: Indices of the environments :return: Sampled goals """ batch_ep_start = self.ep_start[batch_indices, env_indices] diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py index 97d0ad94e..6185e2992 100644 --- a/stable_baselines3/sac/policies.py +++ b/stable_baselines3/sac/policies.py @@ -26,7 +26,7 @@ class Actor(BasePolicy): """ Actor network (policy) for SAC. - :param observation_space: Obervation space + :param observation_space: Observation space :param action_space: Action space :param net_arch: Network architecture :param features_extractor: Network to extract features diff --git a/stable_baselines3/td3/policies.py b/stable_baselines3/td3/policies.py index dda6cb31a..a15be0396 100644 --- a/stable_baselines3/td3/policies.py +++ b/stable_baselines3/td3/policies.py @@ -21,7 +21,7 @@ class Actor(BasePolicy): """ Actor network (policy) for TD3. 
- :param observation_space: Obervation space + :param observation_space: Observation space :param action_space: Action space :param net_arch: Network architecture :param features_extractor: Network to extract features diff --git a/tests/test_buffers.py b/tests/test_buffers.py index 2ea366aff..da6b44a34 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -74,7 +74,7 @@ def step(self, action): @pytest.mark.parametrize("env_cls", [DummyEnv, DummyDictEnv]) def test_env(env_cls): # Check the env used for testing - # Do not warn for assymetric space + # Do not warn for asymmetric space check_env(env_cls(), warn=False, skip_render_check=True) @@ -86,7 +86,7 @@ def test_replay_buffer_normalization(replay_buffer_cls): buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device="cpu") - # Interract and store transitions + # Interact and store transitions env.reset() obs = env.get_original_obs() for _ in range(100): @@ -125,7 +125,7 @@ def test_device_buffer(replay_buffer_cls, device): buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device=device) - # Interract and store transitions + # Interact and store transitions obs = env.reset() for _ in range(100): action = env.action_space.sample() diff --git a/tests/test_cnn.py b/tests/test_cnn.py index e32438c27..4ff31486b 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -161,7 +161,7 @@ def test_features_extractor_target_net(model_class, share_features_extractor): if model_class == TD3: assert id(model.policy.actor_target.features_extractor) != id(model.policy.critic_target.features_extractor) - # Critic and target should be equal at the begginning of training + # Critic and target should be equal at the beginning of training params_should_match(model.critic.parameters(), model.critic_target.parameters()) # TD3 has also a target actor net diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index 14777452e..f093e47e7 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -326,7 +326,7 @@ def test_vec_normalize(model_class): def test_dict_nested(): """ - Make sure we throw an appropiate error with nested Dict observation spaces + Make sure we throw an appropriate error with nested Dict observation spaces """ # Test without manual wrapping to vec-env env = DummyDictEnv(nested_dict_obs=True) diff --git a/tests/test_her.py b/tests/test_her.py index 79b0ac9c6..cb8bfb10f 100644 --- a/tests/test_her.py +++ b/tests/test_her.py @@ -384,7 +384,7 @@ def env_fn(): # for all episodes that are not finished before truncate_last_trajectory: timeouts should be 1 if handle_timeout_termination: assert (replay_buffer.timeouts[pos - 1, env_idx_not_finished] == 1).all() - # episode length sould be != 0 -> episode can be sampled + # episode length should be != 0 -> episode can be sampled assert (replay_buffer.ep_length[pos - 1] != 0).all() # replay buffer should not have changed after truncate_last_trajectory (except dones[pos-1]) diff --git a/tests/test_logger.py b/tests/test_logger.py index dfd9e5567..dfa3691ed 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -479,7 +479,7 @@ def get_printed(self) -> str: class DummySuccessEnv(gym.Env): """ - Create a dummy success environment that returns wether True or False for info['is_success'] + Create a dummy success environment that returns whether True or False for info['is_success'] at the end of an episode according to its dummy successes list """ @@ -536,7 +536,7 @@ def test_rollout_success_rate_on_policy_algorithm(tmp_path): 
Test if the rollout/success_rate information is correctly logged with on policy algorithms To do so, create a dummy environment that takes as argument dummy successes (i.e when an episode) - is going to be successfull or not. + is going to be successful or not. """ STATS_WINDOW_SIZE = 10 diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index 2b30d5ad1..b7d71b748 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -484,7 +484,7 @@ def test_non_dict_obs_keys(): with pytest.raises(ValueError, match=".*is applicable only.*"): _make_warmstart(lambda: DummyRewardEnv(), norm_obs_keys=["key"]) - with pytest.raises(ValueError, match=".* explicitely pass the observation keys.*"): + with pytest.raises(ValueError, match=".* explicitly pass the observation keys.*"): _make_warmstart(lambda: DummyMixedDictEnv()) # Ignore Discrete observation key From 6c00565778e5815e4589afc7499aafbd020535ae Mon Sep 17 00:00:00 2001 From: Ole Petersen <56505957+peteole@users.noreply.github.com> Date: Wed, 15 May 2024 15:59:32 +0200 Subject: [PATCH 21/30] Fix memory leak in base_class.py (#1908) * Fix memory leak in base_class.py Loading the data return value is not necessary since it is unused. Loading the data causes a memory leak through the ep_info_buffer variable. I found this while loading a PPO learner from storage on a multi-GPU system since the ep_info_buffer is loaded to the memory location it was on while it was saved to disk, instead of the target loading location, and is then not cleaned up. * Update changelog.rst * Update changelog --------- Co-authored-by: Antonin RAFFIN --- docs/misc/changelog.rst | 8 +++++--- stable_baselines3/common/base_class.py | 2 +- stable_baselines3/version.txt | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index a4c333e08..d77388534 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.4.0a0 (WIP) +Release 2.4.0a1 (WIP) -------------------------- Breaking Changes: @@ -14,6 +14,8 @@ New Features: Bug Fixes: ^^^^^^^^^^ +- Fixed memory leak when loading learner from storage, ``set_parameters()`` does not try to load the object data anymore + and only loads the PyTorch parameters (@peteole) - Cast type in compute gae method to avoid error when using torch compile (@amjames) `SB3-Contrib`_ @@ -30,7 +32,7 @@ Deprecations: Others: ^^^^^^^ -- Fix various typos (@cschindlbeck) +- Fixed various typos (@cschindlbeck) Bug Fixes: ^^^^^^^^^^ @@ -1659,4 +1661,4 @@ And all the contributors: @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @ReHoss @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto @lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger -@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck +@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 054e58a1f..4be61d65e 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -592,7 +592,7 @@ def set_parameters( if isinstance(load_path_or_dict, dict): params = load_path_or_dict else: - _, params, _ = load_from_zip_file(load_path_or_dict, device=device) + _, 
params, _ = load_from_zip_file(load_path_or_dict, device=device, load_data=False) # Keep track which objects were updated. # `_get_torch_save_params` returns [params, other_pytorch_variables]. diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index e96f44fb3..48adc0106 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.4.0a0 +2.4.0a1 From 0b06d8ab203df5c0fe36bbc395454fba186a4363 Mon Sep 17 00:00:00 2001 From: Joe Ksiazek Date: Wed, 5 Jun 2024 11:27:40 -0400 Subject: [PATCH 22/30] Fix error when loading a model that has net_arch manually set to None (#1937) * Fix loading a model with net_arch=None * Remove redundant get * Dummy commit * Add to contributors * Update test and version --------- Co-authored-by: Antonin Raffin --- docs/misc/changelog.rst | 5 +++-- stable_baselines3/common/base_class.py | 7 +++---- stable_baselines3/version.txt | 2 +- tests/test_save_load.py | 12 ++++++++++++ 4 files changed, 19 insertions(+), 7 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index d77388534..758615c39 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.4.0a1 (WIP) +Release 2.4.0a2 (WIP) -------------------------- Breaking Changes: @@ -17,6 +17,7 @@ Bug Fixes: - Fixed memory leak when loading learner from storage, ``set_parameters()`` does not try to load the object data anymore and only loads the PyTorch parameters (@peteole) - Cast type in compute gae method to avoid error when using torch compile (@amjames) +- Fixed error when loading a model that has ``net_arch`` manually set to ``None`` (@jak3122) `SB3-Contrib`_ ^^^^^^^^^^^^^^ @@ -1661,4 +1662,4 @@ And all the contributors: @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @ReHoss @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto @lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger -@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole +@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 4be61d65e..b2c967405 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -692,10 +692,9 @@ def load( # noqa: C901 if "device" in data["policy_kwargs"]: del data["policy_kwargs"]["device"] # backward compatibility, convert to new format - if "net_arch" in data["policy_kwargs"] and len(data["policy_kwargs"]["net_arch"]) > 0: - saved_net_arch = data["policy_kwargs"]["net_arch"] - if isinstance(saved_net_arch, list) and isinstance(saved_net_arch[0], dict): - data["policy_kwargs"]["net_arch"] = saved_net_arch[0] + saved_net_arch = data["policy_kwargs"].get("net_arch") + if saved_net_arch and isinstance(saved_net_arch, list) and isinstance(saved_net_arch[0], dict): + data["policy_kwargs"]["net_arch"] = saved_net_arch[0] if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]: raise ValueError( diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 48adc0106..e828d3c3d 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.4.0a1 +2.4.0a2 diff --git a/tests/test_save_load.py 
b/tests/test_save_load.py index 0162e3650..c7df7b26f 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -797,3 +797,15 @@ def test_cast_lr_schedule(tmp_path): model = PPO.load(tmp_path / "ppo.zip") assert type(model.lr_schedule(1.0)) is float # noqa: E721 assert np.allclose(model.lr_schedule(0.5), 0.5 * np.sin(1.0)) + + +def test_save_load_net_arch_none(tmp_path): + """ + Test that the model is loaded correctly when net_arch is manually set to None. + See GH#1928 + """ + PPO("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=None)).save(tmp_path / "ppo.zip") + model = PPO.load(tmp_path / "ppo.zip") + # None has been replaced by the default net arch + assert model.policy.net_arch is not None + os.remove(tmp_path / "ppo.zip") From 4efee92fbad70f85aa094e27bd0a740274121795 Mon Sep 17 00:00:00 2001 From: will-maclean <41996719+will-maclean@users.noreply.github.com> Date: Fri, 7 Jun 2024 22:07:28 +1000 Subject: [PATCH 23/30] Set CallbackList children's parent correctly (#1939) * Fixing #1791 * Update test and version * Add test for callback after eval * Fix mypy error * Remove tqdm warnings --------- Co-authored-by: Antonin RAFFIN --- docs/misc/changelog.rst | 7 ++++--- pyproject.toml | 2 ++ stable_baselines3/common/buffers.py | 2 +- stable_baselines3/common/callbacks.py | 4 ++++ stable_baselines3/common/policies.py | 2 +- stable_baselines3/version.txt | 2 +- tests/test_callbacks.py | 26 ++++++++++++++++++++++++++ 7 files changed, 39 insertions(+), 6 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 758615c39..d6df00956 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.4.0a2 (WIP) +Release 2.4.0a3 (WIP) -------------------------- Breaking Changes: @@ -17,7 +17,8 @@ Bug Fixes: - Fixed memory leak when loading learner from storage, ``set_parameters()`` does not try to load the object data anymore and only loads the PyTorch parameters (@peteole) - Cast type in compute gae method to avoid error when using torch compile (@amjames) -- Fixed error when loading a model that has ``net_arch`` manually set to ``None`` (@jak3122) +- ``CallbackList`` now sets the ``.parent`` attribute of child callbacks to its own ``.parent``. 
(will-maclean) +- Fixed error when loading a model that has ``net_arch`` manually set to ``None`` (@jak3122) `SB3-Contrib`_ ^^^^^^^^^^^^^^ @@ -1662,4 +1663,4 @@ And all the contributors: @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @ReHoss @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto @lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger -@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 +@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 @will-maclean diff --git a/pyproject.toml b/pyproject.toml index ce0a14e0f..8e20ffe00 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,8 @@ filterwarnings = [ "ignore::DeprecationWarning:tensorboard", # Gymnasium warnings "ignore::UserWarning:gymnasium", + # tqdm warning about rich being experimental + "ignore:rich is experimental" ] markers = [ "expensive: marks tests as expensive (deselect with '-m \"not expensive\"')" diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index 651ecdb2d..b2fc5a710 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -419,7 +419,7 @@ def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarra :param dones: if the last step was a terminal step (one bool for each env). """ # Convert to numpy - last_values = last_values.clone().cpu().numpy().flatten() + last_values = last_values.clone().cpu().numpy().flatten() # type: ignore[assignment] last_gae_lam = 0 for step in reversed(range(self.buffer_size)): diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py index 48b6011d1..c7841866b 100644 --- a/stable_baselines3/common/callbacks.py +++ b/stable_baselines3/common/callbacks.py @@ -204,6 +204,10 @@ def _init_callback(self) -> None: for callback in self.callbacks: callback.init_callback(self.model) + # Fix for https://github.com/DLR-RM/stable-baselines3/issues/1791 + # pass through the parent callback to all children + callback.parent = self.parent + def _on_training_start(self) -> None: for callback in self.callbacks: callback.on_training_start(self.locals, self.globals) diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 3c9b14aaa..f9c4285dc 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -367,7 +367,7 @@ def predict( with th.no_grad(): actions = self._predict(obs_tensor, deterministic=deterministic) # Convert to numpy, and reshape to the original action shape - actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape)) # type: ignore[misc] + actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape)) # type: ignore[misc, assignment] if isinstance(self.action_space, spaces.Box): if self.squash_output: diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index e828d3c3d..fdd5a5f23 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.4.0a2 +2.4.0a3 diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index d159c43e8..ffc37320f 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -264,3 +264,29 @@ def test_checkpoint_additional_info(tmp_path): model = DQN.load(checkpoint_dir / 
"rl_model_200_steps.zip") model.load_replay_buffer(checkpoint_dir / "rl_model_replay_buffer_200_steps.pkl") VecNormalize.load(checkpoint_dir / "rl_model_vecnormalize_200_steps.pkl", dummy_vec_env) + + +def test_eval_callback_chaining(tmp_path): + class DummyCallback(BaseCallback): + def _on_step(self): + # Check that the parent callback is an EvalCallback + assert isinstance(self.parent, EvalCallback) + assert hasattr(self.parent, "best_mean_reward") + return True + + stop_on_threshold_callback = StopTrainingOnRewardThreshold(reward_threshold=-200, verbose=1) + + eval_callback = EvalCallback( + gym.make("Pendulum-v1"), + best_model_save_path=tmp_path, + log_path=tmp_path, + eval_freq=32, + deterministic=True, + render=False, + callback_on_new_best=CallbackList([DummyCallback(), stop_on_threshold_callback]), + callback_after_eval=CallbackList([DummyCallback()]), + warn=False, + ) + + model = PPO("MlpPolicy", "Pendulum-v1", n_steps=64, n_epochs=1) + model.learn(64, callback=eval_callback) From 24ebf1a1df5f4d51d8163344c840e63e4994090a Mon Sep 17 00:00:00 2001 From: Dominik Baron Date: Sat, 29 Jun 2024 20:07:32 +0200 Subject: [PATCH 24/30] Remove unnecessary SDE resampling in PPO update (#1933) * Remove unnecessary SDE resampling in PPO update * Update changelog.rst * Update version * Update PyTorch version on CI * Update ruff * Limit NumPy version * Reformat --------- Co-authored-by: Antonin RAFFIN --- .github/workflows/ci.yml | 2 +- Makefile | 2 +- docs/misc/changelog.rst | 6 +++++- setup.py | 2 +- stable_baselines3/ppo/ppo.py | 4 ---- stable_baselines3/version.txt | 2 +- tests/test_save_load.py | 4 ++-- 7 files changed, 11 insertions(+), 11 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b1078cd28..0efc16e56 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -32,7 +32,7 @@ jobs: run: | python -m pip install --upgrade pip # cpu version of pytorch - pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu + pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu # Install Atari Roms pip install autorom diff --git a/Makefile b/Makefile index e0f6b2b0c..d9b922515 100644 --- a/Makefile +++ b/Makefile @@ -20,7 +20,7 @@ lint: # see https://www.flake8rules.com/ ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full # exit-zero treats all errors as warnings. - ruff check ${LINT_PATHS} --exit-zero + ruff check ${LINT_PATHS} --exit-zero --output-format=concise format: # Sort imports diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index d6df00956..8df321129 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.4.0a3 (WIP) +Release 2.4.0a4 (WIP) -------------------------- Breaking Changes: @@ -19,6 +19,7 @@ Bug Fixes: - Cast type in compute gae method to avoid error when using torch compile (@amjames) - ``CallbackList`` now sets the ``.parent`` attribute of child callbacks to its own ``.parent``. 
(will-maclean) - Fixed error when loading a model that has ``net_arch`` manually set to ``None`` (@jak3122) +- Set requirement numpy<2.0 until PyTorch is compatible (https://github.com/pytorch/pytorch/issues/107302) `SB3-Contrib`_ ^^^^^^^^^^^^^^ @@ -35,6 +36,8 @@ Deprecations: Others: ^^^^^^^ - Fixed various typos (@cschindlbeck) +- Remove unnecessary SDE noise resampling in PPO update (@brn-dev) +- Updated PyTorch version on CI to 2.3.1 Bug Fixes: ^^^^^^^^^^ @@ -1664,3 +1667,4 @@ And all the contributors: @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto @lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger @marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 @will-maclean +@brn-dev diff --git a/setup.py b/setup.py index 161539af1..9d56dfd77 100644 --- a/setup.py +++ b/setup.py @@ -101,7 +101,7 @@ package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ "gymnasium>=0.28.1,<0.30", - "numpy>=1.20", + "numpy>=1.20,<2.0", # PyTorch not compatible https://github.com/pytorch/pytorch/issues/107302 "torch>=1.13", # For saving models "cloudpickle", diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index ea7cf5ed4..52ee2eb64 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -210,10 +210,6 @@ def train(self) -> None: # Convert discrete action from float to long actions = rollout_data.actions.long().flatten() - # Re-sample the noise matrix because the log_std has changed - if self.use_sde: - self.policy.reset_noise(self.batch_size) - values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions) values = values.flatten() # Normalize advantage diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index fdd5a5f23..2d22b1587 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.4.0a3 +2.4.0a4 diff --git a/tests/test_save_load.py b/tests/test_save_load.py index c7df7b26f..5dc6ca7bf 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -791,11 +791,11 @@ def test_cast_lr_schedule(tmp_path): # Note: for recent version of numpy, np.float64 is a subclass of float # so we need to use type here # assert isinstance(model.lr_schedule(1.0), float) - assert type(model.lr_schedule(1.0)) is float # noqa: E721 + assert type(model.lr_schedule(1.0)) is float assert np.allclose(model.lr_schedule(0.5), 0.5 * np.sin(1.0)) model.save(tmp_path / "ppo.zip") model = PPO.load(tmp_path / "ppo.zip") - assert type(model.lr_schedule(1.0)) is float # noqa: E721 + assert type(model.lr_schedule(1.0)) is float assert np.allclose(model.lr_schedule(0.5), 0.5 * np.sin(1.0)) From 0eebde7ca129632e2d256b26eb405fab481804d0 Mon Sep 17 00:00:00 2001 From: Sahit Chintalapudi Date: Fri, 5 Jul 2024 09:00:48 -0400 Subject: [PATCH 25/30] Fix typo in examples.rst (#1962) The variable `env` is not defined. 
The gym env we want to change is `vec_env` --- docs/guide/examples.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 67a477769..9f5423162 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -436,7 +436,7 @@ will compute a running average and standard deviation of input features (it can log_dir = "/tmp/" model.save(log_dir + "ppo_halfcheetah") stats_path = os.path.join(log_dir, "vec_normalize.pkl") - env.save(stats_path) + vec_env.save(stats_path) # To demonstrate loading del model, vec_env From d8148deeaad3dbd1fb2b601e6f21d71f210366b1 Mon Sep 17 00:00:00 2001 From: Corentin <111868204+corentinlger@users.noreply.github.com> Date: Fri, 5 Jul 2024 19:07:55 +0200 Subject: [PATCH 26/30] Updated DQN optimizer input to only include q_network parameters as input (#1963) * Updated DQN optimizer input to only include q_network parameters * Update version --------- Co-authored-by: Antonin RAFFIN --- docs/misc/changelog.rst | 3 ++- stable_baselines3/dqn/policies.py | 2 +- stable_baselines3/version.txt | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 8df321129..78eb2bd0e 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.4.0a4 (WIP) +Release 2.4.0a5 (WIP) -------------------------- Breaking Changes: @@ -20,6 +20,7 @@ Bug Fixes: - ``CallbackList`` now sets the ``.parent`` attribute of child callbacks to its own ``.parent``. (will-maclean) - Fixed error when loading a model that has ``net_arch`` manually set to ``None`` (@jak3122) - Set requirement numpy<2.0 until PyTorch is compatible (https://github.com/pytorch/pytorch/issues/107302) +- Updated DQN optimizer input to only include q_network parameters, removing the target_q_network ones (@corentinlger) `SB3-Contrib`_ ^^^^^^^^^^^^^^ diff --git a/stable_baselines3/dqn/policies.py b/stable_baselines3/dqn/policies.py index 9d2cf94df..bfefc8137 100644 --- a/stable_baselines3/dqn/policies.py +++ b/stable_baselines3/dqn/policies.py @@ -167,7 +167,7 @@ def _build(self, lr_schedule: Schedule) -> None: # Setup optimizer with initial learning rate self.optimizer = self.optimizer_class( # type: ignore[call-arg] - self.parameters(), + self.q_net.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs, ) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 2d22b1587..a1fd35b5f 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.4.0a4 +2.4.0a5 From 1a69fc831414626cbbcf13343c6e78d9accb9104 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Mon, 15 Jul 2024 23:57:24 +0200 Subject: [PATCH 27/30] Update examples.rst (#1969) --- docs/guide/examples.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 9f5423162..32158172b 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -179,9 +179,9 @@ Multiprocessing with off-policy algorithms vec_env = make_vec_env("Pendulum-v0", n_envs=4, seed=0) - # We collect 4 transitions per call to `ènv.step()` - # and performs 2 gradient steps per call to `ènv.step()` - # if gradient_steps=-1, then we would do 4 gradients steps per call to `ènv.step()` + # We collect 4 transitions per call to `env.step()` + # and performs 2 gradient steps per call to `env.step()` + # if 
gradient_steps=-1, then we would do 4 gradients steps per call to `env.step()` model = SAC("MlpPolicy", vec_env, train_freq=1, gradient_steps=2, verbose=1) model.learn(total_timesteps=10_000) From 000544cc1fe6a1c1ec80c125dadad11ad49e1473 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 22 Jul 2024 13:42:33 +0200 Subject: [PATCH 28/30] Add support for pre and post linear modules in `create_mlp` (#1975) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add support for pre and post linear modules in `create_mlp` * Disable mypy for python 3.8 * Reformat toml file * Update docstring Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> * Add some comments --------- Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- .github/workflows/ci.yml | 72 ++++++++++++------------ docs/misc/changelog.rst | 3 +- pyproject.toml | 32 ++++++----- stable_baselines3/common/torch_layers.py | 53 ++++++++++++++--- stable_baselines3/version.txt | 2 +- tests/test_custom_policy.py | 56 ++++++++++++++++++ 6 files changed, 157 insertions(+), 61 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0efc16e56..822e0cb3f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -5,9 +5,9 @@ name: CI on: push: - branches: [ master ] + branches: [master] pull_request: - branches: [ master ] + branches: [master] jobs: build: @@ -23,38 +23,40 @@ jobs: python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - # cpu version of pytorch - pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + # cpu version of pytorch + pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu - # Install Atari Roms - pip install autorom - wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64 - base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz - AutoROM --accept-license --source-file Roms.tar.gz + # Install Atari Roms + pip install autorom + wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64 + base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz + AutoROM --accept-license --source-file Roms.tar.gz - pip install .[extra_no_roms,tests,docs] - # Use headless version - pip install opencv-python-headless - - name: Lint with ruff - run: | - make lint - - name: Build the doc - run: | - make doc - - name: Check codestyle - run: | - make check-codestyle - - name: Type check - run: | - make type - - name: Test with pytest - run: | - make pytest + pip install .[extra_no_roms,tests,docs] + # Use headless version + pip install opencv-python-headless + - name: Lint with ruff + run: | + make lint + - name: Build the doc + run: | + make doc + - name: Check codestyle + run: | + make check-codestyle + - name: Type check + run: | + make type + # Do not run for python 3.8 (mypy internal error) + if: matrix.python-version 
!= '3.8' + - name: Test with pytest + run: | + make pytest diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 78eb2bd0e..31ff99d09 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.4.0a5 (WIP) +Release 2.4.0a6 (WIP) -------------------------- Breaking Changes: @@ -11,6 +11,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ +- Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ) Bug Fixes: ^^^^^^^^^^ diff --git a/pyproject.toml b/pyproject.toml index 8e20ffe00..dd435a33e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,10 +13,10 @@ ignore = ["B028", "RUF013"] [tool.ruff.lint.per-file-ignores] # Default implementation in abstract methods -"./stable_baselines3/common/callbacks.py"= ["B027"] -"./stable_baselines3/common/noise.py"= ["B027"] +"./stable_baselines3/common/callbacks.py" = ["B027"] +"./stable_baselines3/common/noise.py" = ["B027"] # ClassVar, implicit optional check not needed for tests -"./tests/*.py"= ["RUF012", "RUF013"] +"./tests/*.py" = ["RUF012", "RUF013"] [tool.ruff.lint.mccabe] @@ -37,9 +37,7 @@ exclude = """(?x)( [tool.pytest.ini_options] # Deterministic ordering for tests; useful for pytest-xdist. -env = [ - "PYTHONHASHSEED=0" -] +env = ["PYTHONHASHSEED=0"] filterwarnings = [ # Tensorboard warnings @@ -47,23 +45,27 @@ filterwarnings = [ # Gymnasium warnings "ignore::UserWarning:gymnasium", # tqdm warning about rich being experimental - "ignore:rich is experimental" + "ignore:rich is experimental", ] markers = [ - "expensive: marks tests as expensive (deselect with '-m \"not expensive\"')" + "expensive: marks tests as expensive (deselect with '-m \"not expensive\"')", ] [tool.coverage.run] disable_warnings = ["couldnt-parse"] branch = false omit = [ - "tests/*", - "setup.py", - # Require graphical interface - "stable_baselines3/common/results_plotter.py", - # Require ffmpeg - "stable_baselines3/common/vec_env/vec_video_recorder.py", + "tests/*", + "setup.py", + # Require graphical interface + "stable_baselines3/common/results_plotter.py", + # Require ffmpeg + "stable_baselines3/common/vec_env/vec_video_recorder.py", ] [tool.coverage.report] -exclude_lines = [ "pragma: no cover", "raise NotImplementedError()", "if typing.TYPE_CHECKING:"] +exclude_lines = [ + "pragma: no cover", + "raise NotImplementedError()", + "if typing.TYPE_CHECKING:", +] diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index bb3ba5de8..234b91551 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Tuple, Type, Union +from typing import Dict, List, Optional, Tuple, Type, Union import gymnasium as gym import torch as th @@ -14,7 +14,7 @@ class BaseFeaturesExtractor(nn.Module): """ Base class that represents a features extractor. - :param observation_space: + :param observation_space: The observation space of the environment :param features_dim: Number of features extracted. """ @@ -26,6 +26,7 @@ def __init__(self, observation_space: gym.Space, features_dim: int = 0) -> None: @property def features_dim(self) -> int: + """The number of features that the extractor outputs.""" return self._features_dim @@ -34,7 +35,7 @@ class FlattenExtractor(BaseFeaturesExtractor): Feature extract that flatten the input. Used as a placeholder when feature extraction is not needed. 
- :param observation_space: + :param observation_space: The observation space of the environment """ def __init__(self, observation_space: gym.Space) -> None: @@ -52,7 +53,7 @@ class NatureCNN(BaseFeaturesExtractor): "Human-level control through deep reinforcement learning." Nature 518.7540 (2015): 529-533. - :param observation_space: + :param observation_space: The observation space of the environment :param features_dim: Number of features extracted. This corresponds to the number of unit for the last layer. :param normalized_image: Whether to assume that the image is already normalized @@ -113,13 +114,15 @@ def create_mlp( activation_fn: Type[nn.Module] = nn.ReLU, squash_output: bool = False, with_bias: bool = True, + pre_linear_modules: Optional[List[Type[nn.Module]]] = None, + post_linear_modules: Optional[List[Type[nn.Module]]] = None, ) -> List[nn.Module]: """ Create a multi layer perceptron (MLP), which is a collection of fully-connected layers each followed by an activation function. :param input_dim: Dimension of the input vector - :param output_dim: + :param output_dim: Dimension of the output (last layer, for instance, the number of actions) :param net_arch: Architecture of the neural net It represents the number of units per layer. The length of this list is the number of layers. @@ -128,20 +131,52 @@ def create_mlp( :param squash_output: Whether to squash the output using a Tanh activation function :param with_bias: If set to False, the layers will not learn an additive bias - :return: + :param pre_linear_modules: List of nn.Module to add before the linear layers. + These modules should maintain the input tensor dimension (e.g. BatchNorm). + The number of input features is passed to the module's constructor. + Compared to post_linear_modules, they are used before the output layer (output_dim > 0). + :param post_linear_modules: List of nn.Module to add after the linear layers + (and before the activation function). These modules should maintain the input + tensor dimension (e.g. Dropout, LayerNorm). They are not used after the + output layer (output_dim > 0). The number of input features is passed to + the module's constructor. 
+ :return: The list of layers of the neural network """ + pre_linear_modules = pre_linear_modules or [] + post_linear_modules = post_linear_modules or [] + + modules = [] if len(net_arch) > 0: - modules = [nn.Linear(input_dim, net_arch[0], bias=with_bias), activation_fn()] - else: - modules = [] + # BatchNorm maintains input dim + for module in pre_linear_modules: + modules.append(module(input_dim)) + + modules.append(nn.Linear(input_dim, net_arch[0], bias=with_bias)) + + # LayerNorm, Dropout maintain output dim + for module in post_linear_modules: + modules.append(module(net_arch[0])) + + modules.append(activation_fn()) for idx in range(len(net_arch) - 1): + for module in pre_linear_modules: + modules.append(module(net_arch[idx])) + modules.append(nn.Linear(net_arch[idx], net_arch[idx + 1], bias=with_bias)) + + for module in post_linear_modules: + modules.append(module(net_arch[idx + 1])) + modules.append(activation_fn()) if output_dim > 0: last_layer_dim = net_arch[-1] if len(net_arch) > 0 else input_dim + # Only add BatchNorm before output layer + for module in pre_linear_modules: + modules.append(module(last_layer_dim)) + modules.append(nn.Linear(last_layer_dim, output_dim, bias=with_bias)) if squash_output: modules.append(nn.Tanh()) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index a1fd35b5f..464a5c4dc 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.4.0a5 +2.4.0a6 diff --git a/tests/test_custom_policy.py b/tests/test_custom_policy.py index 1f89b23d6..e92ffe8b7 100644 --- a/tests/test_custom_policy.py +++ b/tests/test_custom_policy.py @@ -1,8 +1,10 @@ import pytest import torch as th +import torch.nn as nn from stable_baselines3 import A2C, DQN, PPO, SAC, TD3 from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike +from stable_baselines3.common.torch_layers import create_mlp @pytest.mark.parametrize( @@ -62,3 +64,57 @@ def test_tf_like_rmsprop_optimizer(): def test_dqn_custom_policy(): policy_kwargs = dict(optimizer_class=RMSpropTFLike, net_arch=[32]) _ = DQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, learning_starts=100).learn(300) + + +def test_create_mlp(): + net = create_mlp(4, 2, net_arch=[16, 8], squash_output=True) + # We cannot compare the network directly because the modules have different ids + # assert net == [nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 2), + # nn.Tanh()] + assert len(net) == 6 + assert isinstance(net[0], nn.Linear) + assert net[0].in_features == 4 + assert net[0].out_features == 16 + assert isinstance(net[1], nn.ReLU) + assert isinstance(net[2], nn.Linear) + assert isinstance(net[4], nn.Linear) + assert net[4].in_features == 8 + assert net[4].out_features == 2 + assert isinstance(net[5], nn.Tanh) + + # Linear network + net = create_mlp(4, -1, net_arch=[]) + assert net == [] + + # No output layer, with custom activation function + net = create_mlp(6, -1, net_arch=[8], activation_fn=nn.Tanh) + # assert net == [nn.Linear(6, 8), nn.Tanh()] + assert len(net) == 2 + assert isinstance(net[0], nn.Linear) + assert net[0].in_features == 6 + assert net[0].out_features == 8 + assert isinstance(net[1], nn.Tanh) + + # Using pre-linear and post-linear modules + pre_linear = [nn.BatchNorm1d] + post_linear = [nn.LayerNorm] + net = create_mlp(6, 2, net_arch=[8, 12], pre_linear_modules=pre_linear, post_linear_modules=post_linear) + # assert net == [nn.BatchNorm1d(6), nn.Linear(6, 8), nn.LayerNorm(8), nn.ReLU() + # nn.BatchNorm1d(6), 
nn.Linear(8, 12), nn.LayerNorm(12), nn.ReLU(), + # nn.BatchNorm1d(12), nn.Linear(12, 2)] # Last layer does not have post_linear + assert len(net) == 10 + assert isinstance(net[0], nn.BatchNorm1d) + assert net[0].num_features == 6 + assert isinstance(net[1], nn.Linear) + assert isinstance(net[2], nn.LayerNorm) + assert isinstance(net[3], nn.ReLU) + assert isinstance(net[4], nn.BatchNorm1d) + assert isinstance(net[5], nn.Linear) + assert net[5].in_features == 8 + assert net[5].out_features == 12 + assert isinstance(net[6], nn.LayerNorm) + assert isinstance(net[7], nn.ReLU) + assert isinstance(net[8], nn.BatchNorm1d) + assert isinstance(net[-1], nn.Linear) + assert net[-1].in_features == 12 + assert net[-1].out_features == 2 From bd3c0c653068a6af1993df7be1a12acfb4be0127 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 26 Jul 2024 14:57:55 +0200 Subject: [PATCH 29/30] Fix loading of optimizer with older DQN models (#1978) --- docs/misc/changelog.rst | 11 ++++++++++- stable_baselines3/common/base_class.py | 27 ++++++++++++++++++++++++-- stable_baselines3/version.txt | 2 +- tests/test_save_load.py | 14 ++++++++++++- 4 files changed, 49 insertions(+), 5 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 31ff99d09..37a035478 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,9 +3,16 @@ Changelog ========== -Release 2.4.0a6 (WIP) +Release 2.4.0a7 (WIP) -------------------------- +.. note:: + + DQN (and QR-DQN) models saved with SB3 < 2.4.0 will show a warning about + truncation of optimizer state when loaded with SB3 >= 2.4.0. + To suppress the warning, simply save the model again. + You can find more info in `PR #1963 `_ + Breaking Changes: ^^^^^^^^^^^^^^^^^ @@ -28,9 +35,11 @@ Bug Fixes: `RL Zoo`_ ^^^^^^^^^ +- Updated defaults hyperparameters for TQC/SAC for Swimmer-v4 (decrease gamma for more consistent results) `SBX`_ (SB3 + Jax) ^^^^^^^^^^^^^^^^^^ +- Added CNN support for DQN Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index b2c967405..e43955f94 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -742,13 +742,13 @@ def load( # noqa: C901 # put state_dicts back in place model.set_parameters(params, exact_match=True, device=device) except RuntimeError as e: - # Patch to load Policy saved using SB3 < 1.7.0 + # Patch to load policies saved using SB3 < 1.7.0 # the error is probably due to old policy being loaded # See https://github.com/DLR-RM/stable-baselines3/issues/1233 if "pi_features_extractor" in str(e) and "Missing key(s) in state_dict" in str(e): model.set_parameters(params, exact_match=False, device=device) warnings.warn( - "You are probably loading a model saved with SB3 < 1.7.0, " + "You are probably loading a A2C/PPO model saved with SB3 < 1.7.0, " "we deactivated exact_match so you can save the model " "again to avoid issues in the future " "(see https://github.com/DLR-RM/stable-baselines3/issues/1233 for more info). 
" @@ -757,6 +757,29 @@ def load( # noqa: C901 ) else: raise e + except ValueError as e: + # Patch to load DQN policies saved using SB3 < 2.4.0 + # The target network params are no longer in the optimizer + # See https://github.com/DLR-RM/stable-baselines3/pull/1963 + saved_optim_params = params["policy.optimizer"]["param_groups"][0]["params"] # type: ignore[index] + n_params_saved = len(saved_optim_params) + n_params = len(model.policy.optimizer.param_groups[0]["params"]) + if n_params_saved == 2 * n_params: + # Truncate to include only online network params + params["policy.optimizer"]["param_groups"][0]["params"] = saved_optim_params[:n_params] # type: ignore[index] + + model.set_parameters(params, exact_match=True, device=device) + warnings.warn( + "You are probably loading a DQN model saved with SB3 < 2.4.0, " + "we truncated the optimizer state so you can save the model " + "again to avoid issues in the future " + "(see https://github.com/DLR-RM/stable-baselines3/pull/1963 for more info). " + f"Original error: {e} \n" + "Note: the model should still work fine, this only a warning." + ) + else: + raise e + # put other pytorch variables back in place if pytorch_variables is not None: for name in pytorch_variables: diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 464a5c4dc..f5230e413 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.4.0a6 +2.4.0a7 diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 5dc6ca7bf..962088246 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -340,7 +340,7 @@ def test_save_load_env_cnn(tmp_path, model_class): # clear file from os os.remove(tmp_path / "test_save.zip") - # Check we can load models saved with SB3 < 1.7.0 + # Check we can load A2C/PPO models saved with SB3 < 1.7.0 if model_class == A2C: del model.policy.pi_features_extractor model.save(tmp_path / "test_save") @@ -809,3 +809,15 @@ def test_save_load_net_arch_none(tmp_path): # None has been replaced by the default net arch assert model.policy.net_arch is not None os.remove(tmp_path / "ppo.zip") + + +def test_save_load_no_target_params(tmp_path): + # Check we can load DQN models saved with SB3 < 2.4.0 + model = DQN("MlpPolicy", "CartPole-v1", buffer_size=10000, learning_starts=4) + env = model.get_env() + # Include target net params + model.policy.optimizer = th.optim.Adam(model.policy.parameters(), lr=0.001) + model.save(tmp_path / "test_save") + with pytest.warns(UserWarning): + DQN.load(str(tmp_path / "test_save.zip"), env=env).learn(20) + os.remove(tmp_path / "test_save.zip") From 6ad6fa55b6e38c8456dd333f71fe45373f66fe90 Mon Sep 17 00:00:00 2001 From: Chris Schindlbeck Date: Mon, 29 Jul 2024 10:44:23 +0200 Subject: [PATCH 30/30] Fix various typos (#1981) --- CODE_OF_CONDUCT.md | 2 +- stable_baselines3/common/on_policy_algorithm.py | 2 +- stable_baselines3/her/her_replay_buffer.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index 137c95744..0ca033815 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -5,7 +5,7 @@ We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender -identity and expression, level of experience, education, socio-economic status, +identity and expression, level of experience, education, socioeconomic status, nationality, 
personal appearance, race, religion, or sexual identity and orientation. diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index 1ba36d5f0..262453721 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -208,7 +208,7 @@ def collect_rollouts( # Reshape in case of discrete action actions = actions.reshape(-1, 1) - # Handle timeout by bootstraping with value function + # Handle timeout by bootstrapping with value function # see GitHub issue #633 for idx, done in enumerate(dones): if ( diff --git a/stable_baselines3/her/her_replay_buffer.py b/stable_baselines3/her/her_replay_buffer.py index 579c6ebf1..20214e72c 100644 --- a/stable_baselines3/her/her_replay_buffer.py +++ b/stable_baselines3/her/her_replay_buffer.py @@ -396,7 +396,7 @@ def truncate_last_trajectory(self) -> None: "If you are in the same episode as when the replay buffer was saved,\n" "you should use `truncate_last_trajectory=False` to avoid that issue." ) - # only consider epsiodes that are not finished + # only consider episodes that are not finished for env_idx in np.where(self._current_ep_start != self.pos)[0]: # set done = True for last episodes self.dones[self.pos - 1, env_idx] = True
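Editor's closing illustration of the DQN loading fix above: as the changelog note says, the warning about truncated optimizer state for models saved with SB3 < 2.4.0 goes away once the model is saved again with the new version. A minimal sketch (the checkpoint path is hypothetical):

.. code-block:: python

    from stable_baselines3 import DQN

    # Hypothetical path to a checkpoint created with SB3 < 2.4.0
    model = DQN.load("dqn_cartpole_old.zip")  # emits the truncation warning once
    model.save("dqn_cartpole_old.zip")

    # Subsequent loads use the truncated (q_network-only) optimizer state and no longer warn
    model = DQN.load("dqn_cartpole_old.zip")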