diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b1078cd28..822e0cb3f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -5,9 +5,9 @@ name: CI on: push: - branches: [ master ] + branches: [master] pull_request: - branches: [ master ] + branches: [master] jobs: build: @@ -23,38 +23,40 @@ jobs: python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - # cpu version of pytorch - pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + # cpu version of pytorch + pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu - # Install Atari Roms - pip install autorom - wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64 - base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz - AutoROM --accept-license --source-file Roms.tar.gz + # Install Atari Roms + pip install autorom + wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64 + base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz + AutoROM --accept-license --source-file Roms.tar.gz - pip install .[extra_no_roms,tests,docs] - # Use headless version - pip install opencv-python-headless - - name: Lint with ruff - run: | - make lint - - name: Build the doc - run: | - make doc - - name: Check codestyle - run: | - make check-codestyle - - name: Type check - run: | - make type - - name: Test with pytest - run: | - make pytest + pip install .[extra_no_roms,tests,docs] + # Use headless version + pip install opencv-python-headless + - name: Lint with ruff + run: | + make lint + - name: Build the doc + run: | + make doc + - name: Check codestyle + run: | + make check-codestyle + - name: Type check + run: | + make type + # Do not run for python 3.8 (mypy internal error) + if: matrix.python-version != '3.8' + - name: Test with pytest + run: | + make pytest diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index 137c95744..0ca033815 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -5,7 +5,7 @@ We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender -identity and expression, level of experience, education, socio-economic status, +identity and expression, level of experience, education, socioeconomic status, nationality, personal appearance, race, religion, or sexual identity and orientation. diff --git a/Makefile b/Makefile index fe9f6ae2e..d9b922515 100644 --- a/Makefile +++ b/Makefile @@ -18,19 +18,19 @@ type: mypy lint: # stop the build if there are Python syntax errors or undefined names # see https://www.flake8rules.com/ - ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --show-source + ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full # exit-zero treats all errors as warnings. 
- ruff ${LINT_PATHS} --exit-zero + ruff check ${LINT_PATHS} --exit-zero --output-format=concise format: # Sort imports - ruff --select I ${LINT_PATHS} --fix + ruff check --select I ${LINT_PATHS} --fix # Reformat using black black ${LINT_PATHS} check-codestyle: # Sort imports - ruff --select I ${LINT_PATHS} + ruff check --select I ${LINT_PATHS} # Reformat using black black --check ${LINT_PATHS} diff --git a/README.md b/README.md index 4f427087b..78592bae8 100644 --- a/README.md +++ b/README.md @@ -85,7 +85,7 @@ Documentation is available online: [https://sb3-contrib.readthedocs.io/](https:/ ## Stable-Baselines Jax (SBX) -[Stable Baselines Jax (SBX)](https://github.com/araffin/sbx) is a proof of concept version of Stable-Baselines3 in Jax. +[Stable Baselines Jax (SBX)](https://github.com/araffin/sbx) is a proof of concept version of Stable-Baselines3 in Jax, with recent algorithms like DroQ or CrossQ. It provides a minimal number of features compared to SB3 but can be much faster (up to 20x times!): https://twitter.com/araffin2/status/1590714558628253698 @@ -127,7 +127,7 @@ import gymnasium as gym from stable_baselines3 import PPO -env = gym.make("CartPole-v1") +env = gym.make("CartPole-v1", render_mode="human") model = PPO("MlpPolicy", env, verbose=1) model.learn(total_timesteps=10_000) @@ -192,7 +192,7 @@ All the following examples can be executed online using Google Colab notebooks: 1: Implemented in [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) GitHub repository. Actions `gym.spaces`: - * `Box`: A N-dimensional box that containes every point in the action space. + * `Box`: A N-dimensional box that contains every point in the action space. * `Discrete`: A list of possible actions, where each timestep only one of the actions can be used. * `MultiDiscrete`: A list of possible actions, where each timestep only one action of each discrete set can be used. * `MultiBinary`: A list of possible actions, where each timestep any of the actions can be used in any combination. diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst index 33ac3ba46..d5e7ae1d2 100644 --- a/docs/guide/algos.rst +++ b/docs/guide/algos.rst @@ -43,7 +43,8 @@ Actions ``gym.spaces``: .. note:: - More algorithms (like QR-DQN or TQC) are implemented in our :ref:`contrib repo `. + More algorithms (like QR-DQN or TQC) are implemented in our :ref:`contrib repo ` + and in our :ref:`SBX (SB3 + Jax) repo ` (DroQ, CrossQ, ...). .. 
note:: diff --git a/docs/guide/callbacks.rst b/docs/guide/callbacks.rst index 239966a6f..472f42114 100644 --- a/docs/guide/callbacks.rst +++ b/docs/guide/callbacks.rst @@ -29,24 +29,25 @@ You can find two examples of custom callbacks in the documentation: one for savi :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages """ - def __init__(self, verbose=0): + def __init__(self, verbose: int = 0): super().__init__(verbose) # Those variables will be accessible in the callback # (they are defined in the base class) # The RL model # self.model = None # type: BaseAlgorithm # An alias for self.model.get_env(), the environment used for training - # self.training_env = None # type: Union[gym.Env, VecEnv, None] + # self.training_env # type: VecEnv # Number of time the callback was called # self.n_calls = 0 # type: int + # num_timesteps = n_envs * n times env.step() was called # self.num_timesteps = 0 # type: int # local and global variables - # self.locals = None # type: Dict[str, Any] - # self.globals = None # type: Dict[str, Any] + # self.locals = {} # type: Dict[str, Any] + # self.globals = {} # type: Dict[str, Any] # The logger object, used to report things in the terminal - # self.logger = None # stable_baselines3.common.logger - # # Sometimes, for event callback, it is useful - # # to have access to the parent object + # self.logger # type: stable_baselines3.common.logger.Logger + # Sometimes, for event callback, it is useful + # to have access to the parent object # self.parent = None # type: Optional[BaseCallback] def _on_training_start(self) -> None: diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index a4729bfb3..32158172b 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -128,7 +128,7 @@ Multiprocessing: Unleashing the Power of Vectorized Environments :param env_id: the environment ID :param num_env: the number of environments you wish to have in subprocesses - :param seed: the inital seed for RNG + :param seed: the initial seed for RNG :param rank: index of the subprocess """ def _init(): @@ -179,9 +179,9 @@ Multiprocessing with off-policy algorithms vec_env = make_vec_env("Pendulum-v0", n_envs=4, seed=0) - # We collect 4 transitions per call to `ènv.step()` - # and performs 2 gradient steps per call to `ènv.step()` - # if gradient_steps=-1, then we would do 4 gradients steps per call to `ènv.step()` + # We collect 4 transitions per call to `env.step()` + # and performs 2 gradient steps per call to `env.step()` + # if gradient_steps=-1, then we would do 4 gradients steps per call to `env.step()` model = SAC("MlpPolicy", vec_env, train_freq=1, gradient_steps=2, verbose=1) model.learn(total_timesteps=10_000) @@ -436,7 +436,7 @@ will compute a running average and standard deviation of input features (it can log_dir = "/tmp/" model.save(log_dir + "ppo_halfcheetah") stats_path = os.path.join(log_dir, "vec_normalize.pkl") - env.save(stats_path) + vec_env.save(stats_path) # To demonstrate loading del model, vec_env diff --git a/docs/guide/export.rst b/docs/guide/export.rst index cccf30014..88a02fe8b 100644 --- a/docs/guide/export.rst +++ b/docs/guide/export.rst @@ -31,53 +31,52 @@ to do inference in another framework. Export to ONNX ----------------- -As of June 2021, ONNX format `doesn't support `_ exporting models that use the ``broadcast_tensors`` functionality of pytorch. So in order to export the trained stable-baseline3 models in the ONNX format, we need to first remove the layers that use broadcasting. 
This can be done by creating a class that removes the unsupported layers. -The following examples are for ``MlpPolicy`` only, and are general examples. Note that you have to preprocess the observation the same way stable-baselines3 agent does (see ``common.preprocessing.preprocess_obs``). +If you are using PyTorch 2.0+ and ONNX Opset 14+, you can easily export SB3 policies using the following code: -For PPO, assuming a shared feature extractor. .. warning:: - The following example is for continuous actions only. - When using discrete or binary actions, you must do some `post-processing `_ - to obtain the action (e.g., convert action logits to action). + The following returns normalized actions and doesn't include the `post-processing `_ step that is done with continuous actions + (clip or unscale the action to the correct space). .. code-block:: python import torch as th + from typing import Tuple from stable_baselines3 import PPO + from stable_baselines3.common.policies import BasePolicy - class OnnxablePolicy(th.nn.Module): - def __init__(self, extractor, action_net, value_net): + class OnnxableSB3Policy(th.nn.Module): + def __init__(self, policy: BasePolicy): super().__init__() - self.extractor = extractor - self.action_net = action_net - self.value_net = value_net + self.policy = policy - def forward(self, observation): - # NOTE: You may have to process (normalize) observation in the correct - # way before using this. See `common.preprocessing.preprocess_obs` - action_hidden, value_hidden = self.extractor(observation) - return self.action_net(action_hidden), self.value_net(value_hidden) + def forward(self, observation: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + # NOTE: Preprocessing is included, but postprocessing + # (clipping/inscaling actions) is not, + # If needed, you also need to transpose the images so that they are channel first + # use deterministic=False if you want to export the stochastic policy + # policy() returns `actions, values, log_prob` for PPO + return self.policy(observation, deterministic=True) # Example: model = PPO("MlpPolicy", "Pendulum-v1") + PPO("MlpPolicy", "Pendulum-v1").save("PathToTrainedModel") model = PPO.load("PathToTrainedModel.zip", device="cpu") - onnxable_model = OnnxablePolicy( - model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net - ) + + onnx_policy = OnnxableSB3Policy(model.policy) observation_size = model.observation_space.shape dummy_input = th.randn(1, *observation_size) th.onnx.export( - onnxable_model, + onnx_policy, dummy_input, "my_ppo_model.onnx", - opset_version=9, + opset_version=17, input_names=["input"], ) @@ -93,7 +92,13 @@ For PPO, assuming a shared feature extractor. observation = np.zeros((1, *observation_size)).astype(np.float32) ort_sess = ort.InferenceSession(onnx_path) - action, value = ort_sess.run(None, {"input": observation}) + actions, values, log_prob = ort_sess.run(None, {"input": observation}) + + print(actions, values, log_prob) + + # Check that the predictions are the same + with th.no_grad(): + print(model.policy(th.as_tensor(observation), deterministic=True)) For SAC the procedure is similar. The example shown only exports the actor network as the actor is sufficient to roll out the trained policies. @@ -108,23 +113,16 @@ For SAC the procedure is similar. 
The example shown only exports the actor netwo class OnnxablePolicy(th.nn.Module): def __init__(self, actor: th.nn.Module): super().__init__() - # Removing the flatten layer because it can't be onnxed - self.actor = th.nn.Sequential( - actor.latent_pi, - actor.mu, - # For gSDE - # th.nn.Hardtanh(min_val=-actor.clip_mean, max_val=actor.clip_mean), - # Squash the output - th.nn.Tanh(), - ) + self.actor = actor def forward(self, observation: th.Tensor) -> th.Tensor: - # NOTE: You may have to process (normalize) observation in the correct - # way before using this. See `common.preprocessing.preprocess_obs` - return self.actor(observation) + # NOTE: You may have to postprocess (unnormalize) actions + # to the correct bounds (see commented code below) + return self.actor(observation, deterministic=True) # Example: model = SAC("MlpPolicy", "Pendulum-v1") + SAC("MlpPolicy", "Pendulum-v1").save("PathToTrainedModel.zip") model = SAC.load("PathToTrainedModel.zip", device="cpu") onnxable_model = OnnxablePolicy(model.policy.actor) @@ -134,7 +132,7 @@ For SAC the procedure is similar. The example shown only exports the actor netwo onnxable_model, dummy_input, "my_sac_actor.onnx", - opset_version=9, + opset_version=17, input_names=["input"], ) @@ -147,10 +145,23 @@ For SAC the procedure is similar. The example shown only exports the actor netwo observation = np.zeros((1, *observation_size)).astype(np.float32) ort_sess = ort.InferenceSession(onnx_path) - action = ort_sess.run(None, {"input": observation}) + scaled_action = ort_sess.run(None, {"input": observation})[0] + + print(scaled_action) + + # Post-process: rescale to correct space + # Rescale the action from [-1, 1] to [low, high] + # low, high = model.action_space.low, model.action_space.high + # post_processed_action = low + (0.5 * (scaled_action + 1.0) * (high - low)) + + # Check that the predictions are the same + with th.no_grad(): + print(model.actor(th.as_tensor(observation), deterministic=True)) + + +For more discussion around the topic, please refer to `GH#383 `_ and `GH#1349 `_. -For more discussion around the topic refer to this `issue. `_ Trace/Export to C++ ------------------- diff --git a/docs/guide/integrations.rst b/docs/guide/integrations.rst index 14573cdec..9f864a2e0 100644 --- a/docs/guide/integrations.rst +++ b/docs/guide/integrations.rst @@ -70,8 +70,10 @@ Installation .. code-block:: bash + # Download model and save it into the logs/ folder - python -m rl_zoo3.load_from_hub --algo a2c --env LunarLander-v2 -orga sb3 -f logs/ + # Only use TRUST_REMOTE_CODE=True with HF models that can be trusted (here the SB3 organization) + TRUST_REMOTE_CODE=True python -m rl_zoo3.load_from_hub --algo a2c --env LunarLander-v2 -orga sb3 -f logs/ # Test the agent python -m rl_zoo3.enjoy --algo a2c --env LunarLander-v2 -f logs/ # Push model, config and hyperparameters to the hub @@ -86,12 +88,19 @@ For instance ``sb3/demo-hf-CartPole-v1``: .. 
code-block:: python + import os + import gymnasium as gym from huggingface_sb3 import load_from_hub from stable_baselines3 import PPO from stable_baselines3.common.evaluation import evaluate_policy + + # Allow the use of `pickle.load()` when downloading model from the hub + # Please make sure that the organization from which you download can be trusted + os.environ["TRUST_REMOTE_CODE"] = "True" + # Retrieve the model from the hub ## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name}) ## filename = name of the model zip file from the repository diff --git a/docs/guide/rl_tips.rst b/docs/guide/rl_tips.rst index ce6f43e55..3acd1b433 100644 --- a/docs/guide/rl_tips.rst +++ b/docs/guide/rl_tips.rst @@ -4,7 +4,7 @@ Reinforcement Learning Tips and Tricks ====================================== -The aim of this section is to help you do reinforcement learning experiments. +The aim of this section is to help you run reinforcement learning experiments. It covers general advice about RL (where to start, which algorithm to choose, how to evaluate an algorithm, ...), as well as tips and tricks when using a custom environment or implementing an RL algorithm. @@ -14,6 +14,11 @@ as well as tips and tricks when using a custom environment or implementing an RL this section in more details. You can also find the `slides here `_. +.. note:: + + We also have a `video on Designing and Running Real-World RL Experiments `_, slides `can be found online `_. + + General advice when using Reinforcement Learning ================================================ @@ -103,19 +108,19 @@ and this `issue `_ by Cé Which algorithm should I use? ============================= -There is no silver bullet in RL, depending on your needs and problem, you may choose one or the other. +There is no silver bullet in RL, you can choose one or the other depending on your needs and problems. The first distinction comes from your action space, i.e., do you have discrete (e.g. LEFT, RIGHT, ...) or continuous actions (ex: go to a certain speed)? -Some algorithms are only tailored for one or the other domain: ``DQN`` only supports discrete actions, where ``SAC`` is restricted to continuous actions. +Some algorithms are only tailored for one or the other domain: ``DQN`` supports only discrete actions, while ``SAC`` is restricted to continuous actions. -The second difference that will help you choose is whether you can parallelize your training or not. +The second difference that will help you decide is whether you can parallelize your training or not. If what matters is the wall clock training time, then you should lean towards ``A2C`` and its derivatives (PPO, ...). Take a look at the `Vectorized Environments `_ to learn more about training with multiple workers. -To accelerate training, you can also take a look at `SBX`_, which is SB3 + Jax, it has fewer features than SB3 but can be up to 20x faster than SB3 PyTorch thanks to JIT compilation of the gradient update. +To accelerate training, you can also take a look at `SBX`_, which is SB3 + Jax, it has less features than SB3 but can be up to 20x faster than SB3 PyTorch thanks to JIT compilation of the gradient update. -In sparse reward settings, we either recommend to use dedicated methods like HER (see below) or population-based algorithms like ARS (available in our :ref:`contrib repo `). 
+In sparse reward settings, we either recommend using either dedicated methods like HER (see below) or population-based algorithms like ARS (available in our :ref:`contrib repo `). To sum it up: @@ -146,7 +151,7 @@ Continuous Actions Continuous Actions - Single Process ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Current State Of The Art (SOTA) algorithms are ``SAC``, ``TD3`` and ``TQC`` (available in our :ref:`contrib repo `). +Current State Of The Art (SOTA) algorithms are ``SAC``, ``TD3``, ``CrossQ`` and ``TQC`` (available in our :ref:`contrib repo ` and :ref:`SBX (SB3 + Jax) repo `). Please use the hyperparameters in the `RL zoo `_ for best results. If you want an extremely sample-efficient algorithm, we recommend using the `DroQ configuration `_ in `SBX`_ (it does many gradient steps per step in the environment). @@ -155,8 +160,7 @@ If you want an extremely sample-efficient algorithm, we recommend using the `Dro Continuous Actions - Multiprocessed ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Take a look at ``PPO``, ``TRPO`` (available in our :ref:`contrib repo `) or ``A2C``. Again, don't forget to take the hyperparameters from the `RL zoo `_ -for continuous actions problems (cf *Bullet* envs). +Take a look at ``PPO``, ``TRPO`` (available in our :ref:`contrib repo `) or ``A2C``. Again, don't forget to take the hyperparameters from the `RL zoo `_ for continuous actions problems (cf *Bullet* envs). .. note:: @@ -181,26 +185,23 @@ Tips and Tricks when creating a custom environment ================================================== If you want to learn about how to create a custom environment, we recommend you read this `page `_. -We also provide a `colab notebook `_ for -a concrete example of creating a custom gym environment. +We also provide a `colab notebook `_ for a concrete example of creating a custom gym environment. Some basic advice: -- always normalize your observation space when you can, i.e., when you know the boundaries -- normalize your action space and make it symmetric when continuous (cf potential issue below) A good practice is to rescale your actions to lie in [-1, 1]. This does not limit you as you can easily rescale the action inside the environment -- start with shaped reward (i.e. informative reward) and simplified version of your problem -- debug with random actions to check that your environment works and follows the gym interface: +- always normalize your observation space if you can, i.e. if you know the boundaries +- normalize your action space and make it symmetric if it is continuous (see potential problem below) A good practice is to rescale your actions so that they lie in [-1, 1]. This does not limit you, as you can easily rescale the action within the environment. +- start with a shaped reward (i.e. informative reward) and a simplified version of your problem +- debug with random actions to check if your environment works and follows the gym interface (with ``check_env``, see below) -Two important things to keep in mind when creating a custom environment is to avoid breaking Markov assumption +Two important things to keep in mind when creating a custom environment are avoiding breaking the Markov assumption and properly handle termination due to a timeout (maximum number of steps in an episode). -For instance, if there is some time delay between action and observation (e.g. due to wifi communication), you should give a history of observations -as input. +For example, if there is a time delay between action and observation (e.g. 
due to wifi communication), you should provide a history of observations as input. Termination due to timeout (max number of steps per episode) needs to be handled separately. You should return ``truncated = True``. If you are using the gym ``TimeLimit`` wrapper, this will be done automatically. -You can read `Time Limit in RL `_ or take a look at the `RL Tips and Tricks video `_ -for more details. +You can read `Time Limit in RL `_, take a look at the `Designing and Running Real-World RL Experiments video `_ or `RL Tips and Tricks video `_ for more details. We provide a helper to check that your environment runs without error: @@ -234,7 +235,7 @@ If you want to quickly try a random agent on your environment, you can also do: Most reinforcement learning algorithms rely on a Gaussian distribution (initially centered at 0 with std 1) for continuous actions. So, if you forget to normalize the action space when using a custom environment, -this can harm learning and be difficult to debug (cf attached image and `issue #473 `_). +this can harm learning and can be difficult to debug (cf attached image and `issue #473 `_). .. figure:: ../_static/img/mistake.png @@ -252,6 +253,12 @@ A better solution would be to use a squashing function (cf ``SAC``) or a Beta di Tips and Tricks when implementing an RL algorithm ================================================= +.. note:: + + We have a `video on YouTube about reliable RL `_ that covers + this section in more details. You can also find the `slides online `_. + + When you try to reproduce a RL paper by implementing the algorithm, the `nuts and bolts of RL research `_ by John Schulman are quite useful (`video `_). @@ -282,4 +289,4 @@ in RL with discrete actions: 3. Pong (one of the easiest Atari game) 4. other Atari games (e.g. Breakout) -.. _SBX: https://github.com/araffin/sbx \ No newline at end of file +.. _SBX: https://github.com/araffin/sbx diff --git a/docs/guide/sbx.rst b/docs/guide/sbx.rst index 52b4348bc..ed5369ea4 100644 --- a/docs/guide/sbx.rst +++ b/docs/guide/sbx.rst @@ -17,6 +17,7 @@ Implemented algorithms: - Deep Q Network (DQN) - Twin Delayed DDPG (TD3) - Deep Deterministic Policy Gradient (DDPG) +- Batch Normalization in Deep Reinforcement Learning (CrossQ) As SBX follows SB3 API, it is also compatible with the `RL Zoo `_. 
@@ -29,16 +30,17 @@ For that you will need to create two files: import rl_zoo3 import rl_zoo3.train from rl_zoo3.train import train - - from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ + from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ rl_zoo3.ALGOS["ddpg"] = DDPG rl_zoo3.ALGOS["dqn"] = DQN - rl_zoo3.ALGOS["droq"] = DroQ + # See SBX readme to use DroQ configuration + # rl_zoo3.ALGOS["droq"] = DroQ rl_zoo3.ALGOS["sac"] = SAC rl_zoo3.ALGOS["ppo"] = PPO rl_zoo3.ALGOS["td3"] = TD3 rl_zoo3.ALGOS["tqc"] = TQC + rl_zoo3.ALGOS["crossq"] = CrossQ rl_zoo3.train.ALGOS = rl_zoo3.ALGOS rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS @@ -56,16 +58,17 @@ Then you can call ``python train_sbx.py --algo sac --env Pendulum-v1`` and use t import rl_zoo3 import rl_zoo3.enjoy from rl_zoo3.enjoy import enjoy - - from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ + from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ rl_zoo3.ALGOS["ddpg"] = DDPG rl_zoo3.ALGOS["dqn"] = DQN - rl_zoo3.ALGOS["droq"] = DroQ + # See SBX readme to use DroQ configuration + # rl_zoo3.ALGOS["droq"] = DroQ rl_zoo3.ALGOS["sac"] = SAC rl_zoo3.ALGOS["ppo"] = PPO rl_zoo3.ALGOS["td3"] = TD3 rl_zoo3.ALGOS["tqc"] = TQC + rl_zoo3.ALGOS["crossq"] = CrossQ rl_zoo3.enjoy.ALGOS = rl_zoo3.ALGOS rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS diff --git a/docs/guide/tensorboard.rst b/docs/guide/tensorboard.rst index 4ef1b496a..ba62e5e6b 100644 --- a/docs/guide/tensorboard.rst +++ b/docs/guide/tensorboard.rst @@ -192,6 +192,7 @@ Here is an example of how to render an episode and log the resulting video to Te import gymnasium as gym import torch as th + import numpy as np from stable_baselines3 import A2C from stable_baselines3.common.callbacks import BaseCallback @@ -226,6 +227,9 @@ Here is an example of how to render an episode and log the resulting video to Te :param _locals: A dictionary containing all local variables of the callback's scope :param _globals: A dictionary containing all global variables of the callback's scope """ + # We expect `render()` to return a uint8 array with values in [0, 255] or a float array + # with values in [0, 1], as described in + # https://pytorch.org/docs/stable/tensorboard.html#torch.utils.tensorboard.writer.SummaryWriter.add_video screen = self._eval_env.render(mode="rgb_array") # PyTorch uses CxHxW vs HxWxC gym (and tensorflow) image convention screens.append(screen.transpose(2, 0, 1)) @@ -239,7 +243,7 @@ Here is an example of how to render an episode and log the resulting video to Te ) self.logger.record( "trajectory/video", - Video(th.ByteTensor([screens]), fps=40), + Video(th.from_numpy(np.asarray([screens])), fps=40), exclude=("stdout", "log", "json", "csv"), ) return True diff --git a/docs/guide/vec_envs.rst b/docs/guide/vec_envs.rst index 792fedecb..c04001c7c 100644 --- a/docs/guide/vec_envs.rst +++ b/docs/guide/vec_envs.rst @@ -96,6 +96,90 @@ SB3 VecEnv API is actually close to Gym 0.21 API but differs to Gym 0.26+ API: ``vec_env.env_method("method_name", args1, args2, kwargs1=kwargs1)`` and ``vec_env.set_attr("attribute_name", new_value)``. +Modifying Vectorized Environments Attributes +-------------------------------------------- + +If you plan to `modify the attributes of an environment `_ while it is used (e.g., modifying an attribute specifying the task carried out for a portion of training when doing multi-task learning, or +a parameter of the environment dynamics), you must expose a setter method. 
+In fact, directly accessing the environment attribute in the callback can lead to unexpected behavior because environments can be wrapped (using gym or VecEnv wrappers, the ``Monitor`` wrapper being one example). + +Consider the following example for a custom env: + +.. code-block:: python + + import gymnasium as gym + from gymnasium import spaces + + from stable_baselines3.common.env_util import make_vec_env + + + class MyMultiTaskEnv(gym.Env): + + def __init__(self): + super().__init__() + """ + A state and action space for robotic locomotion. + The multi-task twist is that the policy would need to adapt to different terrains, each with its own + friction coefficient, mu. + The friction coefficient is the only parameter that changes between tasks. + mu is a scalar between 0 and 1, and during training a callback is used to update mu. + """ + ... + + def step(self, action): + # Do something, depending on the action and current value of mu the next state is computed + return self._get_obs(), reward, done, truncated, info + + def set_mu(self, new_mu: float) -> None: + # Note: this value should be used only at the next reset + self.mu = new_mu + + # Example of wrapped env + # env is of type <TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>> + env = gym.make("CartPole-v1") + # To access the base env, without wrapper, you should use `.unwrapped` + # or env.get_wrapper_attr("gravity") to include wrappers + env.unwrapped.gravity + # SB3 uses VecEnv for training, where `env.unwrapped.x = new_value` cannot be used to set an attribute + # therefore, you should expose a setter like `set_mu` to properly set an attribute + vec_env = make_vec_env(MyMultiTaskEnv) + # Print current mu value + # Note: you should use vec_env.env_method("get_wrapper_attr", "mu") in Gymnasium v1.0 + print(vec_env.env_method("get_wrapper_attr", "mu")) + # Change `mu` attribute via the setter + vec_env.env_method("set_mu", 0.1) + + +In this example ``env.mu`` cannot be accessed/changed directly because it is wrapped in a ``VecEnv`` and because it could be wrapped with other wrappers (see `GH#1573 `_ for a longer explanation). +Instead, the callback should use the ``set_mu`` method via the ``env_method`` method for Vectorized Environments. + +.. code-block:: python + + from itertools import cycle + + class ChangeMuCallback(BaseCallback): + """ + This callback changes the value of mu during training, looping + through a list of values until training is aborted. + The environment is implemented so that the impact of changing + the value of mu mid-episode is visible only after the episode is over + and the reset method has been called. + """ + def __init__(self): + super().__init__() + # An iterator that contains the different values of the friction coefficient + self.mus = cycle([0.1, 0.2, 0.5, 0.13, 0.9]) + + def _on_step(self): + # Note: in practice, you should not change this value at every step + # but rather depending on some events/metrics like agent performance/episode termination + # both accessible via the `self.logger` or `self.locals` variables + self.training_env.env_method("set_mu", next(self.mus)) + +This callback can then be used to safely modify environment attributes during training since +it calls the environment setter method. + + Vectorized Environments Wrappers -------------------------------- diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 6c5d6cc48..e5383c0ed 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,6 +3,181 @@ Changelog ========== +Release 2.4.0a7 (WIP) +-------------------------- + +..
note:: + + DQN (and QR-DQN) models saved with SB3 < 2.4.0 will show a warning about + truncation of optimizer state when loaded with SB3 >= 2.4.0. + To suppress the warning, simply save the model again. + You can find more info in `PR #1963 `_ + +Breaking Changes: +^^^^^^^^^^^^^^^^^ + +New Features: +^^^^^^^^^^^^^ +- Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ) + +Bug Fixes: +^^^^^^^^^^ +- Fixed memory leak when loading learner from storage, ``set_parameters()`` does not try to load the object data anymore + and only loads the PyTorch parameters (@peteole) +- Cast type in compute gae method to avoid error when using torch compile (@amjames) +- ``CallbackList`` now sets the ``.parent`` attribute of child callbacks to its own ``.parent``. (will-maclean) +- Fixed error when loading a model that has ``net_arch`` manually set to ``None`` (@jak3122) +- Set requirement numpy<2.0 until PyTorch is compatible (https://github.com/pytorch/pytorch/issues/107302) +- Updated DQN optimizer input to only include q_network parameters, removing the target_q_network ones (@corentinlger) + +`SB3-Contrib`_ +^^^^^^^^^^^^^^ + +`RL Zoo`_ +^^^^^^^^^ +- Updated defaults hyperparameters for TQC/SAC for Swimmer-v4 (decrease gamma for more consistent results) + +`SBX`_ (SB3 + Jax) +^^^^^^^^^^^^^^^^^^ +- Added CNN support for DQN + +Deprecations: +^^^^^^^^^^^^^ + +Others: +^^^^^^^ +- Fixed various typos (@cschindlbeck) +- Remove unnecessary SDE noise resampling in PPO update (@brn-dev) +- Updated PyTorch version on CI to 2.3.1 + +Bug Fixes: +^^^^^^^^^^ + +Documentation: +^^^^^^^^^^^^^^ + +Release 2.3.2 (2024-04-27) +-------------------------- + +Bug Fixes: +^^^^^^^^^^ +- Reverted ``torch.load()`` to be called ``weights_only=False`` as it caused loading issue with old version of PyTorch. + + +Documentation: +^^^^^^^^^^^^^^ +- Added ER-MRL to the project page (@corentinlger) +- Updated Tensorboard Logging Videos documentation (@NickLucche) + + +Release 2.3.1 (2024-04-22) +-------------------------- + +Bug Fixes: +^^^^^^^^^^ +- Cast return value of learning rate schedule to float, to avoid issue when loading model because of ``weights_only=True`` (@markscsmith) + +Documentation: +^^^^^^^^^^^^^^ +- Updated SBX documentation (CrossQ and deprecated DroQ) +- Updated RL Tips and Tricks section + +Release 2.3.0 (2024-03-31) +-------------------------- + +**New defaults hyperparameters for DDPG, TD3 and DQN** + +.. warning:: + + Because of ``weights_only=True``, this release breaks loading of policies when using PyTorch 1.13. + Please upgrade to PyTorch >= 2.0 or upgrade SB3 version (we reverted the change in SB3 2.3.2) + + +Breaking Changes: +^^^^^^^^^^^^^^^^^ +- The defaults hyperparameters of ``TD3`` and ``DDPG`` have been changed to be more consistent with ``SAC`` + +.. code-block:: python + + # SB3 < 2.3.0 default hyperparameters + # model = TD3("MlpPolicy", env, train_freq=(1, "episode"), gradient_steps=-1, batch_size=100) + # SB3 >= 2.3.0: + model = TD3("MlpPolicy", env, train_freq=1, gradient_steps=1, batch_size=256) + +.. 
note:: + + Two inconsistencies remain: the default network architecture for ``TD3/DDPG`` is ``[400, 300]`` instead of ``[256, 256]`` for SAC (for backward compatibility reasons, see `report on the influence of the network size `_) and the default learning rate is 1e-3 instead of 3e-4 for SAC (for performance reasons, see `W&B report on the influence of the lr `_) + + + +- The default ``learning_starts`` parameter of ``DQN`` have been changed to be consistent with the other offpolicy algorithms + + +.. code-block:: python + + # SB3 < 2.3.0 default hyperparameters, 50_000 corresponded to Atari defaults hyperparameters + # model = DQN("MlpPolicy", env, learning_starts=50_000) + # SB3 >= 2.3.0: + model = DQN("MlpPolicy", env, learning_starts=100) + +- For safety, ``torch.load()`` is now called with ``weights_only=True`` when loading torch tensors, + policy ``load()`` still uses ``weights_only=False`` as gymnasium imports are required for it to work +- When using ``huggingface_sb3``, you will now need to set ``TRUST_REMOTE_CODE=True`` when downloading models from the hub, as ``pickle.load`` is not safe. + + +New Features: +^^^^^^^^^^^^^ +- Log success rate ``rollout/success_rate`` when available for on policy algorithms (@corentinlger) + +Bug Fixes: +^^^^^^^^^^ +- Fixed ``monitor_wrapper`` argument that was not passed to the parent class, and dones argument that wasn't passed to ``_update_into_buffer`` (@corentinlger) + +`SB3-Contrib`_ +^^^^^^^^^^^^^^ +- Added ``rollout_buffer_class`` and ``rollout_buffer_kwargs`` arguments to MaskablePPO +- Fixed ``train_freq`` type annotation for tqc and qrdqn (@Armandpl) +- Fixed ``sb3_contrib/common/maskable/*.py`` type annotations +- Fixed ``sb3_contrib/ppo_mask/ppo_mask.py`` type annotations +- Fixed ``sb3_contrib/common/vec_env/async_eval.py`` type annotations +- Add some additional notes about ``MaskablePPO`` (evaluation and multi-process) (@icheered) + + +`RL Zoo`_ +^^^^^^^^^ +- Updated defaults hyperparameters for TD3/DDPG to be more consistent with SAC +- Upgraded MuJoCo envs hyperparameters to v4 (pre-trained agents need to be updated) +- Added test dependencies to `setup.py` (@power-edge) +- Simplify dependencies of `requirements.txt` (remove duplicates from `setup.py`) + +`SBX`_ (SB3 + Jax) +^^^^^^^^^^^^^^^^^^ +- Added support for ``MultiDiscrete`` and ``MultiBinary`` action spaces to PPO +- Added support for large values for gradient_steps to SAC, TD3, and TQC +- Fix ``train()`` signature and update type hints +- Fix replay buffer device at load time +- Added flatten layer +- Added ``CrossQ`` + +Deprecations: +^^^^^^^^^^^^^ + +Others: +^^^^^^^ +- Updated black from v23 to v24 +- Updated ruff to >= v0.3.1 +- Updated env checker for (multi)discrete spaces with non-zero start. + +Documentation: +^^^^^^^^^^^^^^ +- Added a paragraph on modifying vectorized environment parameters via setters (@fracapuano) +- Updated callback code example +- Updated export to ONNX documentation, it is now much simpler to export SB3 models with newer ONNX Opset! +- Added video link to "Practical Tips for Reliable Reinforcement Learning" video +- Added ``render_mode="human"`` in the README example (@marekm4) +- Fixed docstring signature for sum_independent_dims (@stagoverflow) +- Updated docstring description for ``log_interval`` in the base class (@rushitnshah). 
+ Release 2.2.1 (2023-11-17) -------------------------- **Support for options at reset, bug fixes and better error messages** @@ -248,7 +423,7 @@ Breaking Changes: ^^^^^^^^^^^^^^^^^ - Removed shared layers in ``mlp_extractor`` (@AlexPasqua) - Refactored ``StackedObservations`` (it now handles dict obs, ``StackedDictObservations`` was removed) -- You must now explicitely pass a ``features_extractor`` parameter when calling ``extract_features()`` +- You must now explicitly pass a ``features_extractor`` parameter when calling ``extract_features()`` - Dropped offline sampling for ``HerReplayBuffer`` - As ``HerReplayBuffer`` was refactored to support multiprocessing, previous replay buffer are incompatible with this new version - ``HerReplayBuffer`` doesn't require a ``max_episode_length`` anymore @@ -380,7 +555,7 @@ Bug Fixes: Deprecations: ^^^^^^^^^^^^^ -- You should now explicitely pass a ``features_extractor`` parameter when calling ``extract_features()`` +- You should now explicitly pass a ``features_extractor`` parameter when calling ``extract_features()`` - Deprecated shared layers in ``MlpExtractor`` (@AlexPasqua) Others: @@ -591,7 +766,7 @@ Bug Fixes: - Fixed a bug in ``HumanOutputFormat``. Distinct keys truncated to the same prefix would overwrite each others value, resulting in only one being output. This now raises an error (this should only affect a small fraction of use cases with very long keys.) -- Routing all the ``nn.Module`` calls through implicit rather than explict forward as per pytorch guidelines (@manuel-delverme) +- Routing all the ``nn.Module`` calls through implicit rather than explicit forward as per pytorch guidelines (@manuel-delverme) - Fixed a bug in ``VecNormalize`` where error occurs when ``norm_obs`` is set to False for environment with dictionary observation (@buoyancy99) - Set default ``env`` argument to ``None`` in ``HerReplayBuffer.sample`` (@qgallouedec) - Fix ``batch_size`` typing in ``DQN`` (@qgallouedec) @@ -1491,7 +1666,7 @@ And all the contributors: @flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3 @tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio @diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray -@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @JadenTravnik @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn +@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @fracapuano @JadenTravnik @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn @ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida @09tangriro @amy12xx @juancroldan @benblack769 @bstee615 @c-rizz @skandermoalla @MihaiAnca13 @davidblom603 @ayeright @cyprienc @wkirgsn @AechPro @CUN-bjy @batu @IljaAvadiev @timokau @kachayev @cleversonahum @@ -1503,3 +1678,5 @@ And all the contributors: @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @ReHoss @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto @lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger +@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 @will-maclean +@brn-dev diff --git a/docs/misc/projects.rst b/docs/misc/projects.rst index 2b2e2405c..5f0c69710 100644 --- a/docs/misc/projects.rst +++ 
b/docs/misc/projects.rst @@ -239,3 +239,14 @@ Playing Pokemon Red with Reinforcement Learning. | Author: Peter Whidden | Github: https://github.com/PWhiddy/PokemonRedExperiments | Video: https://www.youtube.com/watch?v=DcYLT37ImBY + + +Evolving Reservoirs for Meta Reinforcement Learning +--------------------------------------------------- + +Meta-RL framework to optimize reservoir-like neural structures (special kind of RNNs), and integrate them to RL agents to improve their training. +It enables solving environments involving partial observability or locomotion (e.g MuJoCo), and optimizing reservoirs that can generalize to unseen tasks. + +| Authors: Corentin Léger, Gautier Hamon, Eleni Nisioti, Xavier Hinaut, Clément Moulin-Frier +| Github: https://github.com/corentinlger/ER-MRL +| Paper: https://arxiv.org/abs/2312.06695 diff --git a/docs/modules/ppo.rst b/docs/modules/ppo.rst index ace2fcc63..b5e667241 100644 --- a/docs/modules/ppo.rst +++ b/docs/modules/ppo.rst @@ -23,7 +23,7 @@ Notes - Original paper: https://arxiv.org/abs/1707.06347 - Clear explanation of PPO on Arxiv Insights channel: https://www.youtube.com/watch?v=5P7I-xPq8u8 -- OpenAI blog post: https://blog.openai.com/openai-baselines-ppo/ +- OpenAI blog post: https://openai.com/research/openai-baselines-ppo - Spinning Up guide: https://spinningup.openai.com/en/latest/algorithms/ppo.html - 37 implementation details blog: https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/ diff --git a/pyproject.toml b/pyproject.toml index 1195687f4..dd435a33e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,21 +3,23 @@ line-length = 127 # Assume Python 3.8 target-version = "py38" + +[tool.ruff.lint] # See https://beta.ruff.rs/docs/rules/ select = ["E", "F", "B", "UP", "C90", "RUF"] # B028: Ignore explicit stacklevel` # RUF013: Too many false positives (implicit optional) ignore = ["B028", "RUF013"] -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] # Default implementation in abstract methods -"./stable_baselines3/common/callbacks.py"= ["B027"] -"./stable_baselines3/common/noise.py"= ["B027"] +"./stable_baselines3/common/callbacks.py" = ["B027"] +"./stable_baselines3/common/noise.py" = ["B027"] # ClassVar, implicit optional check not needed for tests -"./tests/*.py"= ["RUF012", "RUF013"] +"./tests/*.py" = ["RUF012", "RUF013"] -[tool.ruff.mccabe] +[tool.ruff.lint.mccabe] # Unlike Flake8, default to a complexity level of 10. max-complexity = 15 @@ -35,31 +37,35 @@ exclude = """(?x)( [tool.pytest.ini_options] # Deterministic ordering for tests; useful for pytest-xdist. 
-env = [ - "PYTHONHASHSEED=0" -] +env = ["PYTHONHASHSEED=0"] filterwarnings = [ # Tensorboard warnings "ignore::DeprecationWarning:tensorboard", # Gymnasium warnings "ignore::UserWarning:gymnasium", + # tqdm warning about rich being experimental + "ignore:rich is experimental", ] markers = [ - "expensive: marks tests as expensive (deselect with '-m \"not expensive\"')" + "expensive: marks tests as expensive (deselect with '-m \"not expensive\"')", ] [tool.coverage.run] disable_warnings = ["couldnt-parse"] branch = false omit = [ - "tests/*", - "setup.py", - # Require graphical interface - "stable_baselines3/common/results_plotter.py", - # Require ffmpeg - "stable_baselines3/common/vec_env/vec_video_recorder.py", + "tests/*", + "setup.py", + # Require graphical interface + "stable_baselines3/common/results_plotter.py", + # Require ffmpeg + "stable_baselines3/common/vec_env/vec_video_recorder.py", ] [tool.coverage.report] -exclude_lines = [ "pragma: no cover", "raise NotImplementedError()", "if typing.TYPE_CHECKING:"] +exclude_lines = [ + "pragma: no cover", + "raise NotImplementedError()", + "if typing.TYPE_CHECKING:", +] diff --git a/setup.py b/setup.py index 5e10ed66c..9d56dfd77 100644 --- a/setup.py +++ b/setup.py @@ -43,7 +43,7 @@ from stable_baselines3 import PPO -env = gymnasium.make("CartPole-v1") +env = gymnasium.make("CartPole-v1", render_mode="human") model = PPO("MlpPolicy", env, verbose=1) model.learn(total_timesteps=10_000) @@ -101,7 +101,7 @@ package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ "gymnasium>=0.28.1,<0.30", - "numpy>=1.20", + "numpy>=1.20,<2.0", # PyTorch not compatible https://github.com/pytorch/pytorch/issues/107302 "torch>=1.13", # For saving models "cloudpickle", @@ -120,9 +120,9 @@ # Type check "mypy", # Lint code and sort imports (flake8 and isort replacement) - "ruff>=0.0.288", + "ruff>=0.3.1", # Reformat - "black>=23.9.1,<24", + "black>=24.2.0,<25", ], "docs": [ "sphinx>=5,<8", diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 5e8759990..e43955f94 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -48,7 +48,7 @@ def maybe_make_env(env: Union[GymEnv, str], verbose: int) -> GymEnv: """If env is a string, make the environment; otherwise, return env. :param env: The environment to learn from. - :param verbose: Verbosity level: 0 for no output, 1 for indicating if envrironment is created + :param verbose: Verbosity level: 0 for no output, 1 for indicating if environment is created :return A Gym (vector) environment. """ if isinstance(env, str): @@ -523,7 +523,10 @@ def learn( :param total_timesteps: The total number of samples (env steps) to train on :param callback: callback(s) called at every step with state of the algorithm. - :param log_interval: The number of episodes before logging. + :param log_interval: for on-policy algos (e.g., PPO, A2C, ...) this is the number of + training iterations (i.e., log_interval * n_steps * n_envs timesteps) before logging; + for off-policy algos (e.g., TD3, SAC, ...) this is the number of episodes before + logging. :param tb_log_name: the name of the run for TensorBoard logging :param reset_num_timesteps: whether or not to reset the current timestep number (used in logging) :param progress_bar: Display a progress bar using tqdm and rich. 
@@ -589,7 +592,7 @@ def set_parameters( if isinstance(load_path_or_dict, dict): params = load_path_or_dict else: - _, params, _ = load_from_zip_file(load_path_or_dict, device=device) + _, params, _ = load_from_zip_file(load_path_or_dict, device=device, load_data=False) # Keep track of which objects were updated. # `_get_torch_save_params` returns [params, other_pytorch_variables]. @@ -689,10 +692,9 @@ def load( # noqa: C901 if "device" in data["policy_kwargs"]: del data["policy_kwargs"]["device"] # backward compatibility, convert to new format - if "net_arch" in data["policy_kwargs"] and len(data["policy_kwargs"]["net_arch"]) > 0: - saved_net_arch = data["policy_kwargs"]["net_arch"] - if isinstance(saved_net_arch, list) and isinstance(saved_net_arch[0], dict): - data["policy_kwargs"]["net_arch"] = saved_net_arch[0] + saved_net_arch = data["policy_kwargs"].get("net_arch") + if saved_net_arch and isinstance(saved_net_arch, list) and isinstance(saved_net_arch[0], dict): + data["policy_kwargs"]["net_arch"] = saved_net_arch[0] if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]: raise ValueError( @@ -740,13 +742,13 @@ def load( # noqa: C901 # put state_dicts back in place model.set_parameters(params, exact_match=True, device=device) except RuntimeError as e: - # Patch to load Policy saved using SB3 < 1.7.0 + # Patch to load policies saved using SB3 < 1.7.0 # the error is probably due to old policy being loaded # See https://github.com/DLR-RM/stable-baselines3/issues/1233 if "pi_features_extractor" in str(e) and "Missing key(s) in state_dict" in str(e): model.set_parameters(params, exact_match=False, device=device) warnings.warn( - "You are probably loading a model saved with SB3 < 1.7.0, " + "You are probably loading an A2C/PPO model saved with SB3 < 1.7.0, " "we deactivated exact_match so you can save the model " "again to avoid issues in the future " "(see https://github.com/DLR-RM/stable-baselines3/issues/1233 for more info). " @@ -755,6 +757,29 @@ def load( # noqa: C901 ) else: raise e + except ValueError as e: + # Patch to load DQN policies saved using SB3 < 2.4.0 + # The target network params are no longer in the optimizer + # See https://github.com/DLR-RM/stable-baselines3/pull/1963 + saved_optim_params = params["policy.optimizer"]["param_groups"][0]["params"] # type: ignore[index] + n_params_saved = len(saved_optim_params) + n_params = len(model.policy.optimizer.param_groups[0]["params"]) + if n_params_saved == 2 * n_params: + # Truncate to include only online network params + params["policy.optimizer"]["param_groups"][0]["params"] = saved_optim_params[:n_params] # type: ignore[index] + + model.set_parameters(params, exact_match=True, device=device) + warnings.warn( + "You are probably loading a DQN model saved with SB3 < 2.4.0, " + "we truncated the optimizer state so you can save the model " + "again to avoid issues in the future " + "(see https://github.com/DLR-RM/stable-baselines3/pull/1963 for more info). " + f"Original error: {e} \n" + "Note: the model should still work fine, this is only a warning."
+ ) + else: + raise e + # put other pytorch variables back in place if pytorch_variables is not None: for name in pytorch_variables: diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index 306b43571..b2fc5a710 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -419,12 +419,12 @@ def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarra :param dones: if the last step was a terminal step (one bool for each env). """ # Convert to numpy - last_values = last_values.clone().cpu().numpy().flatten() + last_values = last_values.clone().cpu().numpy().flatten() # type: ignore[assignment] last_gae_lam = 0 for step in reversed(range(self.buffer_size)): if step == self.buffer_size - 1: - next_non_terminal = 1.0 - dones + next_non_terminal = 1.0 - dones.astype(np.float32) next_values = last_values else: next_non_terminal = 1.0 - self.episode_starts[step + 1] diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py index 2898df8f4..c7841866b 100644 --- a/stable_baselines3/common/callbacks.py +++ b/stable_baselines3/common/callbacks.py @@ -204,6 +204,10 @@ def _init_callback(self) -> None: for callback in self.callbacks: callback.init_callback(self.model) + # Fix for https://github.com/DLR-RM/stable-baselines3/issues/1791 + # pass through the parent callback to all children + callback.parent = self.parent + def _on_training_start(self) -> None: for callback in self.callbacks: callback.on_training_start(self.locals, self.globals) @@ -606,7 +610,7 @@ def __init__(self, max_episodes: int, verbose: int = 0): self.n_episodes = 0 def _init_callback(self) -> None: - # At start set total max according to number of envirnments + # At start set total max according to number of environments self._total_max_episodes = self.max_episodes * self.training_env.num_envs def _on_step(self) -> bool: diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index 149345d83..132a35348 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -113,7 +113,7 @@ def sum_independent_dims(tensor: th.Tensor) -> th.Tensor: so we can sum components of the ``log_prob`` or the entropy. :param tensor: shape: (n_batch, n_actions) or (n_batch,) - :return: shape: (n_batch,) + :return: shape: (n_batch,) for (n_batch, n_actions) input, scalar for (n_batch,) input """ if len(tensor.shape) > 1: tensor = tensor.sum(dim=1) diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index dc465a1d6..090d609ba 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -17,13 +17,37 @@ def _is_numpy_array_space(space: spaces.Space) -> bool: return not isinstance(space, (spaces.Dict, spaces.Tuple)) +def _starts_at_zero(space: Union[spaces.Discrete, spaces.MultiDiscrete]) -> bool: + """ + Return False if a (Multi)Discrete space has a non-zero start. + """ + return np.allclose(space.start, np.zeros_like(space.start)) + + +def _check_non_zero_start(space: spaces.Space, space_type: str = "observation", key: str = "") -> None: + """ + :param space: Observation or action space + :param space_type: information about whether it is an observation or action space + (for the warning message) + :param key: When the observation space comes from a Dict space, we pass the + corresponding key to have more precise warning messages. Defaults to "". 
+ """ + if isinstance(space, (spaces.Discrete, spaces.MultiDiscrete)) and not _starts_at_zero(space): + maybe_key = f"(key='{key}')" if key else "" + warnings.warn( + f"{type(space).__name__} {space_type} space {maybe_key} with a non-zero start (start={space.start}) " + "is not supported by Stable-Baselines3. " + f"You can use a wrapper or update your {space_type} space." + ) + + def _check_image_input(observation_space: spaces.Box, key: str = "") -> None: """ Check that the input will be compatible with Stable-Baselines when the observation is apparently an image. :param observation_space: Observation space - :key: When the observation space comes from a Dict space, we pass the + :param key: When the observation space comes from a Dict space, we pass the corresponding key to have more precise warning messages. Defaults to "". """ if observation_space.dtype != np.uint8: @@ -63,11 +87,7 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act for key, space in observation_space.spaces.items(): if isinstance(space, spaces.Dict): nested_dict = True - if isinstance(space, spaces.Discrete) and space.start != 0: - warnings.warn( - f"Discrete observation space (key '{key}') with a non-zero start is not supported by Stable-Baselines3. " - "You can use a wrapper or update your observation space." - ) + _check_non_zero_start(space, "observation", key) if nested_dict: warnings.warn( @@ -87,11 +107,7 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act "which is supported by SB3." ) - if isinstance(observation_space, spaces.Discrete) and observation_space.start != 0: - warnings.warn( - "Discrete observation space with a non-zero start is not supported by Stable-Baselines3. " - "You can use a wrapper or update your observation space." - ) + _check_non_zero_start(observation_space, "observation") if isinstance(observation_space, spaces.Sequence): warnings.warn( @@ -100,11 +116,7 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act "Note: The checks for returned values are skipped." ) - if isinstance(action_space, spaces.Discrete) and action_space.start != 0: - warnings.warn( - "Discrete action space with a non-zero start is not supported by Stable-Baselines3. " - "You can use a wrapper or update your action space." 
- ) + _check_non_zero_start(action_space, "action") if not _is_numpy_array_space(action_space): warnings.warn( @@ -385,7 +397,7 @@ def _check_render(env: gym.Env, warn: bool = False) -> None: # pragma: no cover "you may have trouble when calling `.render()`" ) - # Only check currrent render mode + # Only check current render mode if env.render_mode: env.render() env.close() diff --git a/stable_baselines3/common/monitor.py b/stable_baselines3/common/monitor.py index 5253954e8..fb8ce33c6 100644 --- a/stable_baselines3/common/monitor.py +++ b/stable_baselines3/common/monitor.py @@ -189,7 +189,7 @@ def __init__( filename = os.path.realpath(filename) # Create (if any) missing filename directories os.makedirs(os.path.dirname(filename), exist_ok=True) - # Append mode when not overridding existing file + # Append mode when not overriding existing file mode = "w" if override_existing else "a" # Prevent newline issue on Windows, see GH issue #692 self.file_handler = open(filename, f"{mode}t", newline="\n") diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index ddd0f8de2..262453721 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -92,6 +92,7 @@ def __init__( use_sde=use_sde, sde_sample_freq=sde_sample_freq, support_multi_env=True, + monitor_wrapper=monitor_wrapper, seed=seed, stats_window_size=stats_window_size, tensorboard_log=tensorboard_log, @@ -200,14 +201,14 @@ def collect_rollouts( if not callback.on_step(): return False - self._update_info_buffer(infos) + self._update_info_buffer(infos, dones) n_steps += 1 if isinstance(self.action_space, spaces.Discrete): # Reshape in case of discrete action actions = actions.reshape(-1, 1) - # Handle timeout by bootstraping with value function + # Handle timeout by bootstrapping with value function # see GitHub issue #633 for idx, done in enumerate(dones): if ( @@ -250,6 +251,28 @@ def train(self) -> None: """ raise NotImplementedError + def _dump_logs(self, iteration: int) -> None: + """ + Write log. 
+ + :param iteration: Current logging iteration + """ + assert self.ep_info_buffer is not None + assert self.ep_success_buffer is not None + + time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon) + fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed) + self.logger.record("time/iterations", iteration, exclude="tensorboard") + if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: + self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer])) + self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer])) + self.logger.record("time/fps", fps) + self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard") + self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard") + if len(self.ep_success_buffer) > 0: + self.logger.record("rollout/success_rate", safe_mean(self.ep_success_buffer)) + self.logger.dump(step=self.num_timesteps) + def learn( self: SelfOnPolicyAlgorithm, total_timesteps: int, @@ -285,16 +308,7 @@ def learn( # Display training infos if log_interval is not None and iteration % log_interval == 0: assert self.ep_info_buffer is not None - time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon) - fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed) - self.logger.record("time/iterations", iteration, exclude="tensorboard") - if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: - self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer])) - self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer])) - self.logger.record("time/fps", fps) - self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard") - self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard") - self.logger.dump(step=self.num_timesteps) + self._dump_logs(iteration) self.train() diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 50be01c9e..f9c4285dc 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -173,7 +173,9 @@ def load(cls: Type[SelfBaseModel], path: str, device: Union[th.device, str] = "a :return: """ device = get_device(device) - saved_variables = th.load(path, map_location=device) + # Note(antonin): we cannot use `weights_only=True` here because we need to allow + # gymnasium imports for the policy to be loaded successfully + saved_variables = th.load(path, map_location=device, weights_only=False) # Create policy object model = cls(**saved_variables["data"]) @@ -365,7 +367,7 @@ def predict( with th.no_grad(): actions = self._predict(obs_tensor, deterministic=deterministic) # Convert to numpy, and reshape to the original action shape - actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape)) # type: ignore[misc] + actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape)) # type: ignore[misc, assignment] if isinstance(self.action_space, spaces.Box): if self.squash_output: @@ -920,7 +922,7 @@ class ContinuousCritic(BaseModel): By default, it creates two critic networks used to reduce overestimation thanks to clipped Q-learning (cf TD3 paper). 
- :param observation_space: Obervation space + :param observation_space: Observation space :param action_space: Action space :param net_arch: Network architecture :param features_extractor: Network to extract features diff --git a/stable_baselines3/common/results_plotter.py b/stable_baselines3/common/results_plotter.py index 1324557d1..f4c1a7a05 100644 --- a/stable_baselines3/common/results_plotter.py +++ b/stable_baselines3/common/results_plotter.py @@ -46,7 +46,7 @@ def window_func(var_1: np.ndarray, var_2: np.ndarray, window: int, func: Callabl def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray]: """ - Decompose a data frame variable to x ans ys + Decompose a data frame variable to x and ys :param data_frame: the input data :param x_axis: the axis for the x and y output diff --git a/stable_baselines3/common/running_mean_std.py b/stable_baselines3/common/running_mean_std.py index 9dfa4b84c..ac3538c50 100644 --- a/stable_baselines3/common/running_mean_std.py +++ b/stable_baselines3/common/running_mean_std.py @@ -6,7 +6,7 @@ class RunningMeanStd: def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()): """ - Calulates the running mean and std of a data stream + Calculates the running mean and std of a data stream https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm :param epsilon: helps with arithmetic issues diff --git a/stable_baselines3/common/save_util.py b/stable_baselines3/common/save_util.py index 0cbf6d4e2..a85c9c2ec 100644 --- a/stable_baselines3/common/save_util.py +++ b/stable_baselines3/common/save_util.py @@ -2,6 +2,7 @@ Save util taken from stable_baselines used to serialize data (class parameters) of model classes """ + import base64 import functools import io @@ -446,7 +447,8 @@ def load_from_zip_file( file_content.seek(0) # Load the parameters with the right ``map_location``. # Remove ".pth" ending with splitext - th_object = th.load(file_content, map_location=device) + # Note(antonin): we cannot use weights_only=True, as it breaks with PyTorch 1.13, see GH#1911 + th_object = th.load(file_content, map_location=device, weights_only=False) # "tensors.pth" was renamed "pytorch_variables.pth" in v0.9.0, see PR #138 if file_path == "pytorch_variables.pth" or file_path == "tensors.pth": # PyTorch variables (not state_dicts) diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index ad6c7eef1..234b91551 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Tuple, Type, Union +from typing import Dict, List, Optional, Tuple, Type, Union import gymnasium as gym import torch as th @@ -14,7 +14,7 @@ class BaseFeaturesExtractor(nn.Module): """ Base class that represents a features extractor. - :param observation_space: + :param observation_space: The observation space of the environment :param features_dim: Number of features extracted. """ @@ -26,6 +26,7 @@ def __init__(self, observation_space: gym.Space, features_dim: int = 0) -> None: @property def features_dim(self) -> int: + """The number of features that the extractor outputs.""" return self._features_dim @@ -34,7 +35,7 @@ class FlattenExtractor(BaseFeaturesExtractor): Feature extract that flatten the input. Used as a placeholder when feature extraction is not needed. 
- :param observation_space: + :param observation_space: The observation space of the environment """ def __init__(self, observation_space: gym.Space) -> None: @@ -52,7 +53,7 @@ class NatureCNN(BaseFeaturesExtractor): "Human-level control through deep reinforcement learning." Nature 518.7540 (2015): 529-533. - :param observation_space: + :param observation_space: The observation space of the environment :param features_dim: Number of features extracted. This corresponds to the number of unit for the last layer. :param normalized_image: Whether to assume that the image is already normalized @@ -113,13 +114,15 @@ def create_mlp( activation_fn: Type[nn.Module] = nn.ReLU, squash_output: bool = False, with_bias: bool = True, + pre_linear_modules: Optional[List[Type[nn.Module]]] = None, + post_linear_modules: Optional[List[Type[nn.Module]]] = None, ) -> List[nn.Module]: """ Create a multi layer perceptron (MLP), which is a collection of fully-connected layers each followed by an activation function. :param input_dim: Dimension of the input vector - :param output_dim: + :param output_dim: Dimension of the output (last layer, for instance, the number of actions) :param net_arch: Architecture of the neural net It represents the number of units per layer. The length of this list is the number of layers. @@ -128,20 +131,52 @@ def create_mlp( :param squash_output: Whether to squash the output using a Tanh activation function :param with_bias: If set to False, the layers will not learn an additive bias - :return: + :param pre_linear_modules: List of nn.Module to add before the linear layers. + These modules should maintain the input tensor dimension (e.g. BatchNorm). + The number of input features is passed to the module's constructor. + Compared to post_linear_modules, they are used before the output layer (output_dim > 0). + :param post_linear_modules: List of nn.Module to add after the linear layers + (and before the activation function). These modules should maintain the input + tensor dimension (e.g. Dropout, LayerNorm). They are not used after the + output layer (output_dim > 0). The number of input features is passed to + the module's constructor. 
+ :return: The list of layers of the neural network """ + pre_linear_modules = pre_linear_modules or [] + post_linear_modules = post_linear_modules or [] + + modules = [] if len(net_arch) > 0: - modules = [nn.Linear(input_dim, net_arch[0], bias=with_bias), activation_fn()] - else: - modules = [] + # BatchNorm maintains input dim + for module in pre_linear_modules: + modules.append(module(input_dim)) + + modules.append(nn.Linear(input_dim, net_arch[0], bias=with_bias)) + + # LayerNorm, Dropout maintain output dim + for module in post_linear_modules: + modules.append(module(net_arch[0])) + + modules.append(activation_fn()) for idx in range(len(net_arch) - 1): + for module in pre_linear_modules: + modules.append(module(net_arch[idx])) + modules.append(nn.Linear(net_arch[idx], net_arch[idx + 1], bias=with_bias)) + + for module in post_linear_modules: + modules.append(module(net_arch[idx + 1])) + modules.append(activation_fn()) if output_dim > 0: last_layer_dim = net_arch[-1] if len(net_arch) > 0 else input_dim + # Only add BatchNorm before output layer + for module in pre_linear_modules: + modules.append(module(last_layer_dim)) + modules.append(nn.Linear(last_layer_dim, output_dim, bias=with_bias)) if squash_output: modules.append(nn.Tanh()) @@ -189,7 +224,7 @@ def __init__( # save dimensions of layers in policy and value nets if isinstance(net_arch, dict): - # Note: if key is not specificed, assume linear network + # Note: if key is not specified, assume linear network pi_layers_dims = net_arch.get("pi", []) # Layer sizes of the policy network vf_layers_dims = net_arch.get("vf", []) # Layer sizes of the value network else: diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index d75e11531..042c66f9c 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -1,4 +1,5 @@ """Common aliases for type hints""" + from enum import Enum from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Protocol, SupportsFloat, Tuple, Union @@ -23,7 +24,7 @@ PyTorchObs = Union[th.Tensor, TensorDict] # A schedule takes the remaining progress as input -# and ouputs a scalar (e.g. learning rate, clip range, ...) +# and outputs a scalar (e.g. learning rate, clip range, ...) 
Schedule = Callable[[float], float] diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 3ff193786..bcde1cfa0 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -92,7 +92,9 @@ def get_schedule_fn(value_schedule: Union[Schedule, float]) -> Schedule: value_schedule = constant_fn(float(value_schedule)) else: assert callable(value_schedule) - return value_schedule + # Cast to float to avoid unpickling errors to enable weights_only=True, see GH#1900 + # Some types are have odd behaviors when part of a Schedule, like numpy floats + return lambda progress_remaining: float(value_schedule(progress_remaining)) def get_linear_fn(start: float, end: float, end_fraction: float) -> Schedule: diff --git a/stable_baselines3/common/vec_env/patch_gym.py b/stable_baselines3/common/vec_env/patch_gym.py index 2da76a9b2..6ba655ebf 100644 --- a/stable_baselines3/common/vec_env/patch_gym.py +++ b/stable_baselines3/common/vec_env/patch_gym.py @@ -71,7 +71,7 @@ def _convert_space(space: Union["gym.Space", gymnasium.Space]) -> gymnasium.Spac :return: Patched space (gymnasium Space) """ - # Gymnasium space, no convertion to be done + # Gymnasium space, no conversion to be done if isinstance(space, gymnasium.Space): return space diff --git a/stable_baselines3/common/vec_env/util.py b/stable_baselines3/common/vec_env/util.py index 2a03d8e70..855f50edc 100644 --- a/stable_baselines3/common/vec_env/util.py +++ b/stable_baselines3/common/vec_env/util.py @@ -1,6 +1,7 @@ """ Helpers for dealing with vectorized environments. """ + from collections import OrderedDict from typing import Any, Dict, List, Tuple diff --git a/stable_baselines3/common/vec_env/vec_frame_stack.py b/stable_baselines3/common/vec_env/vec_frame_stack.py index d412a96a2..daa2b365c 100644 --- a/stable_baselines3/common/vec_env/vec_frame_stack.py +++ b/stable_baselines3/common/vec_env/vec_frame_stack.py @@ -29,7 +29,12 @@ def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[Union[st def step_wait( self, - ) -> Tuple[Union[np.ndarray, Dict[str, np.ndarray]], np.ndarray, np.ndarray, List[Dict[str, Any]],]: + ) -> Tuple[ + Union[np.ndarray, Dict[str, np.ndarray]], + np.ndarray, + np.ndarray, + List[Dict[str, Any]], + ]: observations, rewards, dones, infos = self.venv.step_wait() observations, infos = self.stacked_obs.update(observations, dones, infos) # type: ignore[arg-type] return observations, rewards, dones, infos diff --git a/stable_baselines3/common/vec_env/vec_normalize.py b/stable_baselines3/common/vec_env/vec_normalize.py index 391ce342d..ab1d8403a 100644 --- a/stable_baselines3/common/vec_env/vec_normalize.py +++ b/stable_baselines3/common/vec_env/vec_normalize.py @@ -111,7 +111,7 @@ def _sanity_checks(self) -> None: raise ValueError( f"VecNormalize only supports `gym.spaces.Box` observation spaces but {obs_key} " f"is of type {self.observation_space.spaces[obs_key]}. " - "You should probably explicitely pass the observation keys " + "You should probably explicitly pass the observation keys " " that should be normalized via the `norm_obs_keys` parameter." 
) diff --git a/stable_baselines3/ddpg/ddpg.py b/stable_baselines3/ddpg/ddpg.py index c311b2357..2fe2fdfc4 100644 --- a/stable_baselines3/ddpg/ddpg.py +++ b/stable_baselines3/ddpg/ddpg.py @@ -60,11 +60,11 @@ def __init__( learning_rate: Union[float, Schedule] = 1e-3, buffer_size: int = 1_000_000, # 1e6 learning_starts: int = 100, - batch_size: int = 100, + batch_size: int = 256, tau: float = 0.005, gamma: float = 0.99, - train_freq: Union[int, Tuple[int, str]] = (1, "episode"), - gradient_steps: int = -1, + train_freq: Union[int, Tuple[int, str]] = 1, + gradient_steps: int = 1, action_noise: Optional[ActionNoise] = None, replay_buffer_class: Optional[Type[ReplayBuffer]] = None, replay_buffer_kwargs: Optional[Dict[str, Any]] = None, diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index 42e3d0df0..894ed9f04 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -79,7 +79,7 @@ def __init__( env: Union[GymEnv, str], learning_rate: Union[float, Schedule] = 1e-4, buffer_size: int = 1_000_000, # 1e6 - learning_starts: int = 50000, + learning_starts: int = 100, batch_size: int = 32, tau: float = 1.0, gamma: float = 0.99, diff --git a/stable_baselines3/dqn/policies.py b/stable_baselines3/dqn/policies.py index 9d2cf94df..bfefc8137 100644 --- a/stable_baselines3/dqn/policies.py +++ b/stable_baselines3/dqn/policies.py @@ -167,7 +167,7 @@ def _build(self, lr_schedule: Schedule) -> None: # Setup optimizer with initial learning rate self.optimizer = self.optimizer_class( # type: ignore[call-arg] - self.parameters(), + self.q_net.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs, ) diff --git a/stable_baselines3/her/her_replay_buffer.py b/stable_baselines3/her/her_replay_buffer.py index 5f0765884..20214e72c 100644 --- a/stable_baselines3/her/her_replay_buffer.py +++ b/stable_baselines3/her/her_replay_buffer.py @@ -255,7 +255,7 @@ def _get_real_samples( Get the samples corresponding to the batch and environment indices. :param batch_indices: Indices of the transitions - :param env_indices: Indices of the envrionments + :param env_indices: Indices of the environments :param env: associated gym VecEnv to normalize the observations/rewards when sampling, defaults to None :return: Samples @@ -294,7 +294,7 @@ def _get_virtual_samples( Get the samples, sample new desired goals and compute new rewards. :param batch_indices: Indices of the transitions - :param env_indices: Indices of the envrionments + :param env_indices: Indices of the environments :param env: associated gym VecEnv to normalize the observations/rewards when sampling, defaults to None :return: Samples, with new desired goals and new rewards @@ -357,7 +357,7 @@ def _sample_goals(self, batch_indices: np.ndarray, env_indices: np.ndarray) -> n Sample goals based on goal_selection_strategy. :param batch_indices: Indices of the transitions - :param env_indices: Indices of the envrionments + :param env_indices: Indices of the environments :return: Sampled goals """ batch_ep_start = self.ep_start[batch_indices, env_indices] @@ -396,7 +396,7 @@ def truncate_last_trajectory(self) -> None: "If you are in the same episode as when the replay buffer was saved,\n" "you should use `truncate_last_trajectory=False` to avoid that issue." 
) - # only consider epsiodes that are not finished + # only consider episodes that are not finished for env_idx in np.where(self._current_ep_start != self.pos)[0]: # set done = True for last episodes self.dones[self.pos - 1, env_idx] = True diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index ea7cf5ed4..52ee2eb64 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -210,10 +210,6 @@ def train(self) -> None: # Convert discrete action from float to long actions = rollout_data.actions.long().flatten() - # Re-sample the noise matrix because the log_std has changed - if self.use_sde: - self.policy.reset_noise(self.batch_size) - values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions) values = values.flatten() # Normalize advantage diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py index 97d0ad94e..6185e2992 100644 --- a/stable_baselines3/sac/policies.py +++ b/stable_baselines3/sac/policies.py @@ -26,7 +26,7 @@ class Actor(BasePolicy): """ Actor network (policy) for SAC. - :param observation_space: Obervation space + :param observation_space: Observation space :param action_space: Action space :param net_arch: Network architecture :param features_extractor: Network to extract features diff --git a/stable_baselines3/td3/policies.py b/stable_baselines3/td3/policies.py index dda6cb31a..a15be0396 100644 --- a/stable_baselines3/td3/policies.py +++ b/stable_baselines3/td3/policies.py @@ -21,7 +21,7 @@ class Actor(BasePolicy): """ Actor network (policy) for TD3. - :param observation_space: Obervation space + :param observation_space: Observation space :param action_space: Action space :param net_arch: Network architecture :param features_extractor: Network to extract features diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index a06ce67e0..a61d954bc 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -83,11 +83,11 @@ def __init__( learning_rate: Union[float, Schedule] = 1e-3, buffer_size: int = 1_000_000, # 1e6 learning_starts: int = 100, - batch_size: int = 100, + batch_size: int = 256, tau: float = 0.005, gamma: float = 0.99, - train_freq: Union[int, Tuple[int, str]] = (1, "episode"), - gradient_steps: int = -1, + train_freq: Union[int, Tuple[int, str]] = 1, + gradient_steps: int = 1, action_noise: Optional[ActionNoise] = None, replay_buffer_class: Optional[Type[ReplayBuffer]] = None, replay_buffer_kwargs: Optional[Dict[str, Any]] = None, diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index c043eea77..f5230e413 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.2.1 +2.4.0a7 diff --git a/tests/test_buffers.py b/tests/test_buffers.py index 2ea366aff..da6b44a34 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -74,7 +74,7 @@ def step(self, action): @pytest.mark.parametrize("env_cls", [DummyEnv, DummyDictEnv]) def test_env(env_cls): # Check the env used for testing - # Do not warn for assymetric space + # Do not warn for asymmetric space check_env(env_cls(), warn=False, skip_render_check=True) @@ -86,7 +86,7 @@ def test_replay_buffer_normalization(replay_buffer_cls): buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device="cpu") - # Interract and store transitions + # Interact and store transitions env.reset() obs = env.get_original_obs() for _ in range(100): @@ -125,7 +125,7 @@ def test_device_buffer(replay_buffer_cls, 
device): buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device=device) - # Interract and store transitions + # Interact and store transitions obs = env.reset() for _ in range(100): action = env.action_space.sample() diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index d159c43e8..ffc37320f 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -264,3 +264,29 @@ def test_checkpoint_additional_info(tmp_path): model = DQN.load(checkpoint_dir / "rl_model_200_steps.zip") model.load_replay_buffer(checkpoint_dir / "rl_model_replay_buffer_200_steps.pkl") VecNormalize.load(checkpoint_dir / "rl_model_vecnormalize_200_steps.pkl", dummy_vec_env) + + +def test_eval_callback_chaining(tmp_path): + class DummyCallback(BaseCallback): + def _on_step(self): + # Check that the parent callback is an EvalCallback + assert isinstance(self.parent, EvalCallback) + assert hasattr(self.parent, "best_mean_reward") + return True + + stop_on_threshold_callback = StopTrainingOnRewardThreshold(reward_threshold=-200, verbose=1) + + eval_callback = EvalCallback( + gym.make("Pendulum-v1"), + best_model_save_path=tmp_path, + log_path=tmp_path, + eval_freq=32, + deterministic=True, + render=False, + callback_on_new_best=CallbackList([DummyCallback(), stop_on_threshold_callback]), + callback_after_eval=CallbackList([DummyCallback()]), + warn=False, + ) + + model = PPO("MlpPolicy", "Pendulum-v1", n_steps=64, n_epochs=1) + model.learn(64, callback=eval_callback) diff --git a/tests/test_cnn.py b/tests/test_cnn.py index e32438c27..4ff31486b 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -161,7 +161,7 @@ def test_features_extractor_target_net(model_class, share_features_extractor): if model_class == TD3: assert id(model.policy.actor_target.features_extractor) != id(model.policy.critic_target.features_extractor) - # Critic and target should be equal at the begginning of training + # Critic and target should be equal at the beginning of training params_should_match(model.critic.parameters(), model.critic_target.parameters()) # TD3 has also a target actor net diff --git a/tests/test_custom_policy.py b/tests/test_custom_policy.py index 1f89b23d6..e92ffe8b7 100644 --- a/tests/test_custom_policy.py +++ b/tests/test_custom_policy.py @@ -1,8 +1,10 @@ import pytest import torch as th +import torch.nn as nn from stable_baselines3 import A2C, DQN, PPO, SAC, TD3 from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike +from stable_baselines3.common.torch_layers import create_mlp @pytest.mark.parametrize( @@ -62,3 +64,57 @@ def test_tf_like_rmsprop_optimizer(): def test_dqn_custom_policy(): policy_kwargs = dict(optimizer_class=RMSpropTFLike, net_arch=[32]) _ = DQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, learning_starts=100).learn(300) + + +def test_create_mlp(): + net = create_mlp(4, 2, net_arch=[16, 8], squash_output=True) + # We cannot compare the network directly because the modules have different ids + # assert net == [nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 2), + # nn.Tanh()] + assert len(net) == 6 + assert isinstance(net[0], nn.Linear) + assert net[0].in_features == 4 + assert net[0].out_features == 16 + assert isinstance(net[1], nn.ReLU) + assert isinstance(net[2], nn.Linear) + assert isinstance(net[4], nn.Linear) + assert net[4].in_features == 8 + assert net[4].out_features == 2 + assert isinstance(net[5], nn.Tanh) + + # Linear network + net = create_mlp(4, -1, net_arch=[]) + assert net == [] + + # No 
output layer, with custom activation function + net = create_mlp(6, -1, net_arch=[8], activation_fn=nn.Tanh) + # assert net == [nn.Linear(6, 8), nn.Tanh()] + assert len(net) == 2 + assert isinstance(net[0], nn.Linear) + assert net[0].in_features == 6 + assert net[0].out_features == 8 + assert isinstance(net[1], nn.Tanh) + + # Using pre-linear and post-linear modules + pre_linear = [nn.BatchNorm1d] + post_linear = [nn.LayerNorm] + net = create_mlp(6, 2, net_arch=[8, 12], pre_linear_modules=pre_linear, post_linear_modules=post_linear) + # assert net == [nn.BatchNorm1d(6), nn.Linear(6, 8), nn.LayerNorm(8), nn.ReLU() + # nn.BatchNorm1d(6), nn.Linear(8, 12), nn.LayerNorm(12), nn.ReLU(), + # nn.BatchNorm1d(12), nn.Linear(12, 2)] # Last layer does not have post_linear + assert len(net) == 10 + assert isinstance(net[0], nn.BatchNorm1d) + assert net[0].num_features == 6 + assert isinstance(net[1], nn.Linear) + assert isinstance(net[2], nn.LayerNorm) + assert isinstance(net[3], nn.ReLU) + assert isinstance(net[4], nn.BatchNorm1d) + assert isinstance(net[5], nn.Linear) + assert net[5].in_features == 8 + assert net[5].out_features == 12 + assert isinstance(net[6], nn.LayerNorm) + assert isinstance(net[7], nn.ReLU) + assert isinstance(net[8], nn.BatchNorm1d) + assert isinstance(net[-1], nn.Linear) + assert net[-1].in_features == 12 + assert net[-1].out_features == 2 diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index 14777452e..f093e47e7 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -326,7 +326,7 @@ def test_vec_normalize(model_class): def test_dict_nested(): """ - Make sure we throw an appropiate error with nested Dict observation spaces + Make sure we throw an appropriate error with nested Dict observation spaces """ # Test without manual wrapping to vec-env env = DummyDictEnv(nested_dict_obs=True) diff --git a/tests/test_envs.py b/tests/test_envs.py index e82ef5768..9a61eeef0 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -123,6 +123,8 @@ def patched_step(_action): spaces.Dict({"img": spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)}), # Non zero start index spaces.Discrete(3, start=-1), + # Non zero start index (MultiDiscrete) + spaces.MultiDiscrete([4, 4], start=[1, 0]), # Non zero start index inside a Dict spaces.Dict({"obs": spaces.Discrete(3, start=1)}), ], @@ -164,6 +166,8 @@ def patched_step(_action): spaces.Box(low=np.array([-1, -1, -1]), high=np.array([1, 1, 0.99]), dtype=np.float32), # Non zero start index spaces.Discrete(3, start=-1), + # Non zero start index (MultiDiscrete) + spaces.MultiDiscrete([4, 4], start=[1, 0]), ], ) def test_non_default_action_spaces(new_action_space): @@ -179,7 +183,7 @@ def test_non_default_action_spaces(new_action_space): env.action_space = new_action_space # Discrete action space - if isinstance(new_action_space, spaces.Discrete): + if isinstance(new_action_space, (spaces.Discrete, spaces.MultiDiscrete)): with pytest.warns(UserWarning): check_env(env) return diff --git a/tests/test_her.py b/tests/test_her.py index 79b0ac9c6..cb8bfb10f 100644 --- a/tests/test_her.py +++ b/tests/test_her.py @@ -384,7 +384,7 @@ def env_fn(): # for all episodes that are not finished before truncate_last_trajectory: timeouts should be 1 if handle_timeout_termination: assert (replay_buffer.timeouts[pos - 1, env_idx_not_finished] == 1).all() - # episode length sould be != 0 -> episode can be sampled + # episode length should be != 0 -> episode can be sampled assert (replay_buffer.ep_length[pos - 1] != 0).all() # replay 
buffer should not have changed after truncate_last_trajectory (except dones[pos-1]) diff --git a/tests/test_logger.py b/tests/test_logger.py index 790cde244..90ae0c9f4 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -15,7 +15,7 @@ from matplotlib import pyplot as plt from pandas.errors import EmptyDataError -from stable_baselines3 import A2C, DQN +from stable_baselines3 import A2C, DQN, PPO from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.logger import ( DEBUG, @@ -34,6 +34,7 @@ read_csv, read_json, ) +from stable_baselines3.common.monitor import Monitor KEY_VALUES = { "test": 1, @@ -534,3 +535,92 @@ def get_printed(self) -> str: """ assert printed == desired_printed + + +class DummySuccessEnv(gym.Env): + """ + Create a dummy success environment that returns whether True or False for info['is_success'] + at the end of an episode according to its dummy successes list + """ + + def __init__(self, dummy_successes, ep_steps): + """Init the dummy success env + + :param dummy_successes: list of size (n_logs_iterations, n_episodes_per_log) that specifies + the success value of log iteration i at episode j + :param ep_steps: number of steps per episode (to activate truncated) + """ + self.n_steps = 0 + self.log_id = 0 + self.ep_id = 0 + + self.ep_steps = ep_steps + + self.dummy_success = dummy_successes + self.num_logs = len(dummy_successes) + self.ep_per_log = len(dummy_successes[0]) + self.steps_per_log = self.ep_per_log * self.ep_steps + + self.action_space = spaces.Discrete(2) + self.observation_space = spaces.Discrete(2) + + def reset(self, seed=None, options=None): + """ + Reset the env and advance to the next episode_id to get the next dummy success + """ + self.n_steps = 0 + + if self.ep_id == self.ep_per_log: + self.ep_id = 0 + self.log_id = (self.log_id + 1) % self.num_logs + + return self.observation_space.sample(), {} + + def step(self, action): + """ + Step and return a dummy success when an episode is truncated + """ + self.n_steps += 1 + truncated = self.n_steps >= self.ep_steps + + info = {} + if truncated: + maybe_success = self.dummy_success[self.log_id][self.ep_id] + info["is_success"] = maybe_success + self.ep_id += 1 + return self.observation_space.sample(), 0.0, False, truncated, info + + +def test_rollout_success_rate_on_policy_algorithm(tmp_path): + """ + Test if the rollout/success_rate information is correctly logged with on policy algorithms + + To do so, create a dummy environment that takes as argument dummy successes (i.e when an episode) + is going to be successful or not. 
+ """ + + STATS_WINDOW_SIZE = 10 + # Add dummy successes with 0.3, 0.5 and 0.8 success_rate of length STATS_WINDOW_SIZE + dummy_successes = [ + [True] * 3 + [False] * 7, + [True] * 5 + [False] * 5, + [True] * 8 + [False] * 2, + ] + ep_steps = 64 + + # Monitor the env to track the success info + monitor_file = str(tmp_path / "monitor.csv") + env = Monitor(DummySuccessEnv(dummy_successes, ep_steps), filename=monitor_file, info_keywords=("is_success",)) + + # Equip the model of a custom logger to check the success_rate info + model = PPO("MlpPolicy", env=env, stats_window_size=STATS_WINDOW_SIZE, n_steps=env.steps_per_log, verbose=1) + logger = InMemoryLogger() + model.set_logger(logger) + + # Make the model learn and check that the success rate corresponds to the ratio of dummy successes + model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1) + assert logger.name_to_value["rollout/success_rate"] == 0.3 + model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1) + assert logger.name_to_value["rollout/success_rate"] == 0.5 + model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1) + assert logger.name_to_value["rollout/success_rate"] == 0.8 diff --git a/tests/test_save_load.py b/tests/test_save_load.py index e7123e984..962088246 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -340,7 +340,7 @@ def test_save_load_env_cnn(tmp_path, model_class): # clear file from os os.remove(tmp_path / "test_save.zip") - # Check we can load models saved with SB3 < 1.7.0 + # Check we can load A2C/PPO models saved with SB3 < 1.7.0 if model_class == A2C: del model.policy.pi_features_extractor model.save(tmp_path / "test_save") @@ -783,3 +783,41 @@ def test_no_resource_warning(tmp_path): fp.seek(0) model.load_replay_buffer(fp) assert not fp.closed + + +def test_cast_lr_schedule(tmp_path): + # See GH#1900 + model = PPO("MlpPolicy", "Pendulum-v1", learning_rate=lambda t: t * np.sin(1.0)) + # Note: for recent version of numpy, np.float64 is a subclass of float + # so we need to use type here + # assert isinstance(model.lr_schedule(1.0), float) + assert type(model.lr_schedule(1.0)) is float + assert np.allclose(model.lr_schedule(0.5), 0.5 * np.sin(1.0)) + model.save(tmp_path / "ppo.zip") + model = PPO.load(tmp_path / "ppo.zip") + assert type(model.lr_schedule(1.0)) is float + assert np.allclose(model.lr_schedule(0.5), 0.5 * np.sin(1.0)) + + +def test_save_load_net_arch_none(tmp_path): + """ + Test that the model is loaded correctly when net_arch is manually set to None. 
+ See GH#1928 + """ + PPO("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=None)).save(tmp_path / "ppo.zip") + model = PPO.load(tmp_path / "ppo.zip") + # None has been replaced by the default net arch + assert model.policy.net_arch is not None + os.remove(tmp_path / "ppo.zip") + + +def test_save_load_no_target_params(tmp_path): + # Check we can load DQN models saved with SB3 < 2.4.0 + model = DQN("MlpPolicy", "CartPole-v1", buffer_size=10000, learning_starts=4) + env = model.get_env() + # Include target net params + model.policy.optimizer = th.optim.Adam(model.policy.parameters(), lr=0.001) + model.save(tmp_path / "test_save") + with pytest.warns(UserWarning): + DQN.load(str(tmp_path / "test_save.zip"), env=env).learn(20) + os.remove(tmp_path / "test_save.zip") diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index 2b30d5ad1..b7d71b748 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -484,7 +484,7 @@ def test_non_dict_obs_keys(): with pytest.raises(ValueError, match=".*is applicable only.*"): _make_warmstart(lambda: DummyRewardEnv(), norm_obs_keys=["key"]) - with pytest.raises(ValueError, match=".* explicitely pass the observation keys.*"): + with pytest.raises(ValueError, match=".* explicitly pass the observation keys.*"): _make_warmstart(lambda: DummyMixedDictEnv()) # Ignore Discrete observation key
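Side note on the env_checker.py changes above: the new _check_non_zero_start helper only warns that a (Multi)Discrete space with a non-zero start is unsupported and suggests "a wrapper or updating the space". A minimal sketch of such an observation wrapper is given below; it assumes only the standard gymnasium ObservationWrapper API and a gymnasium version where MultiDiscrete exposes a start attribute (the patched checker itself relies on this). The ShiftDiscreteObsWrapper name is illustrative and is not part of Stable-Baselines3.

import gymnasium as gym
import numpy as np
from gymnasium import spaces


class ShiftDiscreteObsWrapper(gym.ObservationWrapper):
    """Shift a (Multi)Discrete observation space with a non-zero start back to start=0."""

    def __init__(self, env: gym.Env):
        super().__init__(env)
        space = env.observation_space
        assert isinstance(space, (spaces.Discrete, spaces.MultiDiscrete))
        if isinstance(space, spaces.Discrete):
            self.start = int(space.start)
            self.observation_space = spaces.Discrete(int(space.n))
        else:
            self.start = np.asarray(space.start)
            self.observation_space = spaces.MultiDiscrete(space.nvec)

    def observation(self, observation):
        # Map the smallest possible observation to zero, as expected by SB3
        return observation - self.start

The action-space case is the mirror image: an ActionWrapper that adds the start offset back onto the agent's action before forwarding it to the underlying env.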