Merge branch 'master' into feat/mps-support
araffin authored Sep 18, 2024
2 parents b85a2a5 + 512eea9 commit 955382e
Showing 45 changed files with 539 additions and 167 deletions.
72 changes: 37 additions & 35 deletions .github/workflows/ci.yml
@@ -5,9 +5,9 @@ name: CI

on:
push:
branches: [ master ]
branches: [master]
pull_request:
branches: [ master ]
branches: [master]

jobs:
build:
@@ -23,38 +23,40 @@ jobs:
python-version: ["3.8", "3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu
# Install Atari Roms
pip install autorom
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz
# Install Atari Roms
pip install autorom
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz
pip install .[extra_no_roms,tests,docs]
# Use headless version
pip install opencv-python-headless
- name: Lint with ruff
run: |
make lint
- name: Build the doc
run: |
make doc
- name: Check codestyle
run: |
make check-codestyle
- name: Type check
run: |
make type
- name: Test with pytest
run: |
make pytest
pip install .[extra_no_roms,tests,docs]
# Use headless version
pip install opencv-python-headless
- name: Lint with ruff
run: |
make lint
- name: Build the doc
run: |
make doc
- name: Check codestyle
run: |
make check-codestyle
- name: Type check
run: |
make type
# Do not run for python 3.8 (mypy internal error)
if: matrix.python-version != '3.8'
- name: Test with pytest
run: |
make pytest
2 changes: 1 addition & 1 deletion CODE_OF_CONDUCT.md
@@ -5,7 +5,7 @@
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
identity and expression, level of experience, education, socioeconomic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.

2 changes: 1 addition & 1 deletion Makefile
@@ -20,7 +20,7 @@ lint:
# see https://www.flake8rules.com/
ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
# exit-zero treats all errors as warnings.
ruff check ${LINT_PATHS} --exit-zero
ruff check ${LINT_PATHS} --exit-zero --output-format=concise

format:
# Sort imports
8 changes: 4 additions & 4 deletions README.md
@@ -1,5 +1,3 @@
<img src="docs/\_static/img/logo.png" align="right" width="40%"/>

<!-- [![pipeline status](https://gitlab.com/araffin/stable-baselines3/badges/master/pipeline.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master) -->
![CI](https://github.com/DLR-RM/stable-baselines3/workflows/CI/badge.svg)
[![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/master/coverage.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master)
@@ -8,6 +6,8 @@

# Stable Baselines3

<img src="docs/\_static/img/logo.png" align="right" width="40%"/>

Stable Baselines3 (SB3) is a set of reliable implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines).

You can read a detailed presentation of Stable Baselines3 in the [v1.0 blog post](https://araffin.github.io/post/sb3/) or our [JMLR paper](https://jmlr.org/papers/volume22/20-1364/20-1364.pdf).
@@ -85,7 +85,7 @@ Documentation is available online: [https://sb3-contrib.readthedocs.io/](https:/

## Stable-Baselines Jax (SBX)

[Stable Baselines Jax (SBX)](https://github.com/araffin/sbx) is a proof of concept version of Stable-Baselines3 in Jax.
[Stable Baselines Jax (SBX)](https://github.com/araffin/sbx) is a proof of concept version of Stable-Baselines3 in Jax, with recent algorithms like DroQ or CrossQ.

It provides a minimal number of features compared to SB3 but can be much faster (up to 20x times!): https://twitter.com/araffin2/status/1590714558628253698

@@ -192,7 +192,7 @@ All the following examples can be executed online using Google Colab notebooks:
<b id="f1">1</b>: Implemented in [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) GitHub repository.

Actions `gym.spaces`:
* `Box`: A N-dimensional box that containes every point in the action space.
* `Box`: A N-dimensional box that contains every point in the action space.
* `Discrete`: A list of possible actions, where each timestep only one of the actions can be used.
* `MultiDiscrete`: A list of possible actions, where each timestep only one action of each discrete set can be used.
* `MultiBinary`: A list of possible actions, where each timestep any of the actions can be used in any combination.
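
For reference, a minimal sketch (Gymnasium API, with arbitrary example shapes and bounds) of how the four action-space types listed above are constructed:

    import numpy as np
    from gymnasium import spaces

    # Box: an N-dimensional box, every point between low and high is a valid action
    box = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
    # Discrete: exactly one of n actions per timestep
    discrete = spaces.Discrete(4)
    # MultiDiscrete: one action per discrete set, here 3 sets with 2, 3 and 4 choices
    multi_discrete = spaces.MultiDiscrete([2, 3, 4])
    # MultiBinary: any combination of n binary actions per timestep
    multi_binary = spaces.MultiBinary(5)

    for space in (box, discrete, multi_discrete, multi_binary):
        print(space, space.sample())
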
3 changes: 2 additions & 1 deletion docs/guide/algos.rst
@@ -43,7 +43,8 @@ Actions ``gym.spaces``:

.. note::

More algorithms (like QR-DQN or TQC) are implemented in our :ref:`contrib repo <sb3_contrib>`.
More algorithms (like QR-DQN or TQC) are implemented in our :ref:`contrib repo <sb3_contrib>`
and in our :ref:`SBX (SB3 + Jax) repo <sbx>` (DroQ, CrossQ, ...).

.. note::

10 changes: 5 additions & 5 deletions docs/guide/examples.rst
@@ -128,7 +128,7 @@ Multiprocessing: Unleashing the Power of Vectorized Environments
:param env_id: the environment ID
:param num_env: the number of environments you wish to have in subprocesses
:param seed: the inital seed for RNG
:param seed: the initial seed for RNG
:param rank: index of the subprocess
"""
def _init():
@@ -179,9 +179,9 @@ Multiprocessing with off-policy algorithms
vec_env = make_vec_env("Pendulum-v0", n_envs=4, seed=0)
# We collect 4 transitions per call to `ènv.step()`
# and performs 2 gradient steps per call to `ènv.step()`
# if gradient_steps=-1, then we would do 4 gradients steps per call to `ènv.step()`
# We collect 4 transitions per call to `env.step()`
# and performs 2 gradient steps per call to `env.step()`
# if gradient_steps=-1, then we would do 4 gradients steps per call to `env.step()`
model = SAC("MlpPolicy", vec_env, train_freq=1, gradient_steps=2, verbose=1)
model.learn(total_timesteps=10_000)
@@ -436,7 +436,7 @@ will compute a running average and standard deviation of input features (it can
log_dir = "/tmp/"
model.save(log_dir + "ppo_halfcheetah")
stats_path = os.path.join(log_dir, "vec_normalize.pkl")
env.save(stats_path)
vec_env.save(stats_path)
# To demonstrate loading
del model, vec_env
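
The collapsed example above saves the ``VecNormalize`` statistics to ``vec_normalize.pkl``; a hedged sketch of the matching load step (the environment id and paths are illustrative, not taken from the diff):

    import os

    from stable_baselines3 import PPO
    from stable_baselines3.common.env_util import make_vec_env
    from stable_baselines3.common.vec_env import VecNormalize

    log_dir = "/tmp/"
    stats_path = os.path.join(log_dir, "vec_normalize.pkl")

    # Re-create the vectorized environment, then restore the saved running statistics
    vec_env = make_vec_env("HalfCheetah-v4", n_envs=1)
    vec_env = VecNormalize.load(stats_path, vec_env)
    # At test time: do not update the moving averages and do not normalize the reward
    vec_env.training = False
    vec_env.norm_reward = False

    model = PPO.load(log_dir + "ppo_halfcheetah", env=vec_env)
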
43 changes: 22 additions & 21 deletions docs/guide/rl_tips.rst
@@ -4,7 +4,7 @@
Reinforcement Learning Tips and Tricks
======================================

The aim of this section is to help you do reinforcement learning experiments.
The aim of this section is to help you run reinforcement learning experiments.
It covers general advice about RL (where to start, which algorithm to choose, how to evaluate an algorithm, ...),
as well as tips and tricks when using a custom environment or implementing an RL algorithm.

@@ -14,6 +14,11 @@ as well as tips and tricks when using a custom environment or implementing an RL
this section in more details. You can also find the `slides here <https://araffin.github.io/slides/rlvs-tips-tricks/>`_.


.. note::

We also have a `video on Designing and Running Real-World RL Experiments <https://youtu.be/eZ6ZEpCi6D8>`_, slides `can be found online <https://araffin.github.io/slides/design-real-rl-experiments/>`_.


General advice when using Reinforcement Learning
================================================

@@ -103,19 +108,19 @@ and this `issue <https://github.com/hill-a/stable-baselines/issues/199>`_ by Cé
Which algorithm should I use?
=============================

There is no silver bullet in RL, depending on your needs and problem, you may choose one or the other.
There is no silver bullet in RL, you can choose one or the other depending on your needs and problems.
The first distinction comes from your action space, i.e., do you have discrete (e.g. LEFT, RIGHT, ...)
or continuous actions (ex: go to a certain speed)?

Some algorithms are only tailored for one or the other domain: ``DQN`` only supports discrete actions, where ``SAC`` is restricted to continuous actions.
Some algorithms are only tailored for one or the other domain: ``DQN`` supports only discrete actions, while ``SAC`` is restricted to continuous actions.

The second difference that will help you choose is whether you can parallelize your training or not.
The second difference that will help you decide is whether you can parallelize your training or not.
If what matters is the wall clock training time, then you should lean towards ``A2C`` and its derivatives (PPO, ...).
Take a look at the `Vectorized Environments <vec_envs.html>`_ to learn more about training with multiple workers.

To accelerate training, you can also take a look at `SBX`_, which is SB3 + Jax, it has fewer features than SB3 but can be up to 20x faster than SB3 PyTorch thanks to JIT compilation of the gradient update.
To accelerate training, you can also take a look at `SBX`_, which is SB3 + Jax, it has less features than SB3 but can be up to 20x faster than SB3 PyTorch thanks to JIT compilation of the gradient update.

In sparse reward settings, we either recommend to use dedicated methods like HER (see below) or population-based algorithms like ARS (available in our :ref:`contrib repo <sb3_contrib>`).
In sparse reward settings, we either recommend using either dedicated methods like HER (see below) or population-based algorithms like ARS (available in our :ref:`contrib repo <sb3_contrib>`).

To sum it up:

@@ -146,7 +151,7 @@ Continuous Actions
Continuous Actions - Single Process
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Current State Of The Art (SOTA) algorithms are ``SAC``, ``TD3`` and ``TQC`` (available in our :ref:`contrib repo <sb3_contrib>`).
Current State Of The Art (SOTA) algorithms are ``SAC``, ``TD3``, ``CrossQ`` and ``TQC`` (available in our :ref:`contrib repo <sb3_contrib>` and :ref:`SBX (SB3 + Jax) repo <sbx>`).
Please use the hyperparameters in the `RL zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_ for best results.

If you want an extremely sample-efficient algorithm, we recommend using the `DroQ configuration <https://twitter.com/araffin2/status/1575439865222660098>`_ in `SBX`_ (it does many gradient steps per step in the environment).
@@ -155,8 +160,7 @@ If you want an extremely sample-efficient algorithm, we recommend using the `Dro
Continuous Actions - Multiprocessed
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Take a look at ``PPO``, ``TRPO`` (available in our :ref:`contrib repo <sb3_contrib>`) or ``A2C``. Again, don't forget to take the hyperparameters from the `RL zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_
for continuous actions problems (cf *Bullet* envs).
Take a look at ``PPO``, ``TRPO`` (available in our :ref:`contrib repo <sb3_contrib>`) or ``A2C``. Again, don't forget to take the hyperparameters from the `RL zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_ for continuous actions problems (cf *Bullet* envs).

.. note::

@@ -181,26 +185,23 @@ Tips and Tricks when creating a custom environment
==================================================

If you want to learn about how to create a custom environment, we recommend you read this `page <custom_env.html>`_.
We also provide a `colab notebook <https://colab.research.google.com/github/araffin/rl-tutorial-jnrr19/blob/master/5_custom_gym_env.ipynb>`_ for
a concrete example of creating a custom gym environment.
We also provide a `colab notebook <https://colab.research.google.com/github/araffin/rl-tutorial-jnrr19/blob/master/5_custom_gym_env.ipynb>`_ for a concrete example of creating a custom gym environment.

Some basic advice:

- always normalize your observation space when you can, i.e., when you know the boundaries
- normalize your action space and make it symmetric when continuous (cf potential issue below) A good practice is to rescale your actions to lie in [-1, 1]. This does not limit you as you can easily rescale the action inside the environment
- start with shaped reward (i.e. informative reward) and simplified version of your problem
- debug with random actions to check that your environment works and follows the gym interface:
- always normalize your observation space if you can, i.e. if you know the boundaries
- normalize your action space and make it symmetric if it is continuous (see potential problem below) A good practice is to rescale your actions so that they lie in [-1, 1]. This does not limit you, as you can easily rescale the action within the environment.
- start with a shaped reward (i.e. informative reward) and a simplified version of your problem
- debug with random actions to check if your environment works and follows the gym interface (with ``check_env``, see below)
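
As a concrete illustration of the rescaling advice above (a sketch only, with made-up bounds): the agent outputs actions in [-1, 1] and the environment maps them back to the physical range.

    import numpy as np

    def rescale_action(scaled_action: np.ndarray, low: np.ndarray, high: np.ndarray) -> np.ndarray:
        """Map an action in [-1, 1] back to the original [low, high] interval."""
        return low + 0.5 * (scaled_action + 1.0) * (high - low)

    # Example: a motor command physically limited to [0, 250]
    low, high = np.array([0.0]), np.array([250.0])
    assert np.allclose(rescale_action(np.array([-1.0]), low, high), low)
    assert np.allclose(rescale_action(np.array([1.0]), low, high), high)
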

Two important things to keep in mind when creating a custom environment is to avoid breaking Markov assumption
Two important things to keep in mind when creating a custom environment are avoiding breaking the Markov assumption
and properly handle termination due to a timeout (maximum number of steps in an episode).
For instance, if there is some time delay between action and observation (e.g. due to wifi communication), you should give a history of observations
as input.
For example, if there is a time delay between action and observation (e.g. due to wifi communication), you should provide a history of observations as input.

Termination due to timeout (max number of steps per episode) needs to be handled separately.
You should return ``truncated = True``.
If you are using the gym ``TimeLimit`` wrapper, this will be done automatically.
You can read `Time Limit in RL <https://arxiv.org/abs/1712.00378>`_ or take a look at the `RL Tips and Tricks video <https://www.youtube.com/watch?v=Ikngt0_DXJg>`_
for more details.
You can read `Time Limit in RL <https://arxiv.org/abs/1712.00378>`_, take a look at the `Designing and Running Real-World RL Experiments video <https://youtu.be/eZ6ZEpCi6D8>`_ or `RL Tips and Tricks video <https://www.youtube.com/watch?v=Ikngt0_DXJg>`_ for more details.
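
A short sketch (Gymnasium API, illustrative environment id) of the two usual ways to handle the timeout signal described above:

    import gymnasium as gym
    from gymnasium.wrappers import TimeLimit

    # Option 1: inside a custom env's step(), report the timeout yourself:
    #   terminated = True  -> real end of the episode (task solved or failed)
    #   truncated  = True  -> episode cut short, e.g. max number of steps reached
    # return obs, reward, terminated, truncated, info

    # Option 2: let the TimeLimit wrapper set truncated=True automatically
    env = TimeLimit(gym.make("Pendulum-v1"), max_episode_steps=200)
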


We provide a helper to check that your environment runs without error:
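
(The corresponding snippet is collapsed in this diff; as a minimal sketch, the checker is typically invoked like this, with a built-in environment standing in for a custom one:)

    import gymnasium as gym
    from stable_baselines3.common.env_checker import check_env

    # Replace with an instance of your own custom environment
    env = gym.make("Pendulum-v1")
    # Warns (or raises) if the environment does not follow the Gym/Gymnasium API
    check_env(env)
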
@@ -234,7 +235,7 @@ If you want to quickly try a random agent on your environment, you can also do:

Most reinforcement learning algorithms rely on a Gaussian distribution (initially centered at 0 with std 1) for continuous actions.
So, if you forget to normalize the action space when using a custom environment,
this can harm learning and be difficult to debug (cf attached image and `issue #473 <https://github.com/hill-a/stable-baselines/issues/473>`_).
this can harm learning and can be difficult to debug (cf attached image and `issue #473 <https://github.com/hill-a/stable-baselines/issues/473>`_).

.. figure:: ../_static/img/mistake.png

15 changes: 9 additions & 6 deletions docs/guide/sbx.rst
@@ -17,6 +17,7 @@ Implemented algorithms:
- Deep Q Network (DQN)
- Twin Delayed DDPG (TD3)
- Deep Deterministic Policy Gradient (DDPG)
- Batch Normalization in Deep Reinforcement Learning (CrossQ)


As SBX follows SB3 API, it is also compatible with the `RL Zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_.
@@ -29,16 +30,17 @@ For that you will need to create two files:
import rl_zoo3
import rl_zoo3.train
from rl_zoo3.train import train
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ
rl_zoo3.ALGOS["ddpg"] = DDPG
rl_zoo3.ALGOS["dqn"] = DQN
rl_zoo3.ALGOS["droq"] = DroQ
# See SBX readme to use DroQ configuration
# rl_zoo3.ALGOS["droq"] = DroQ
rl_zoo3.ALGOS["sac"] = SAC
rl_zoo3.ALGOS["ppo"] = PPO
rl_zoo3.ALGOS["td3"] = TD3
rl_zoo3.ALGOS["tqc"] = TQC
rl_zoo3.ALGOS["crossq"] = CrossQ
rl_zoo3.train.ALGOS = rl_zoo3.ALGOS
rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS
@@ -56,16 +58,17 @@ Then you can call ``python train_sbx.py --algo sac --env Pendulum-v1`` and use t
import rl_zoo3
import rl_zoo3.enjoy
from rl_zoo3.enjoy import enjoy
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ
rl_zoo3.ALGOS["ddpg"] = DDPG
rl_zoo3.ALGOS["dqn"] = DQN
rl_zoo3.ALGOS["droq"] = DroQ
# See SBX readme to use DroQ configuration
# rl_zoo3.ALGOS["droq"] = DroQ
rl_zoo3.ALGOS["sac"] = SAC
rl_zoo3.ALGOS["ppo"] = PPO
rl_zoo3.ALGOS["td3"] = TD3
rl_zoo3.ALGOS["tqc"] = TQC
rl_zoo3.ALGOS["crossq"] = CrossQ
rl_zoo3.enjoy.ALGOS = rl_zoo3.ALGOS
rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS
6 changes: 5 additions & 1 deletion docs/guide/tensorboard.rst
@@ -192,6 +192,7 @@ Here is an example of how to render an episode and log the resulting video to Te
import gymnasium as gym
import torch as th
import numpy as np
from stable_baselines3 import A2C
from stable_baselines3.common.callbacks import BaseCallback
@@ -226,6 +227,9 @@ Here is an example of how to render an episode and log the resulting video to Te
:param _locals: A dictionary containing all local variables of the callback's scope
:param _globals: A dictionary containing all global variables of the callback's scope
"""
# We expect `render()` to return a uint8 array with values in [0, 255] or a float array
# with values in [0, 1], as described in
# https://pytorch.org/docs/stable/tensorboard.html#torch.utils.tensorboard.writer.SummaryWriter.add_video
screen = self._eval_env.render(mode="rgb_array")
# PyTorch uses CxHxW vs HxWxC gym (and tensorflow) image convention
screens.append(screen.transpose(2, 0, 1))
@@ -239,7 +243,7 @@ Here is an example of how to render an episode and log the resulting video to Te
)
self.logger.record(
"trajectory/video",
Video(th.ByteTensor([screens]), fps=40),
Video(th.from_numpy(np.asarray([screens])), fps=40),
exclude=("stdout", "log", "json", "csv"),
)
return True