Skip to content

Commit

Permalink
Merge branch 'master' into prioritized-experience-replay
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin authored Nov 5, 2024
2 parents 5c0c79d + 8f0b488 commit 148e4aa
Show file tree
Hide file tree
Showing 36 changed files with 549 additions and 222 deletions.
83 changes: 46 additions & 37 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ name: CI

on:
push:
branches: [ master ]
branches: [master]
pull_request:
branches: [ master ]
branches: [master]

jobs:
build:
Expand All @@ -21,40 +21,49 @@ jobs:
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]

include:
# Default version
- gymnasium-version: "1.0.0"
# Add a new config to test gym<1.0
- python-version: "3.10"
gymnasium-version: "0.29.1"
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu
# Install Atari Roms
pip install autorom
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# Use uv for faster downloads
pip install uv
# cpu version of pytorch
# See https://github.com/astral-sh/uv/issues/1497
uv pip install --system torch==2.3.1+cpu --index https://download.pytorch.org/whl/cpu
pip install .[extra_no_roms,tests,docs]
# Use headless version
pip install opencv-python-headless
- name: Lint with ruff
run: |
make lint
- name: Build the doc
run: |
make doc
- name: Check codestyle
run: |
make check-codestyle
- name: Type check
run: |
make type
- name: Test with pytest
run: |
make pytest
uv pip install --system .[extra,tests,docs]
# Use headless version
uv pip install --system opencv-python-headless
- name: Install specific version of gym
run: |
uv pip install --system gymnasium==${{ matrix.gymnasium-version }}
# Only run for python 3.10, downgrade gym to 0.29.1
if: matrix.gymnasium-version != '1.0.0'
- name: Lint with ruff
run: |
make lint
- name: Build the doc
run: |
make doc
- name: Check codestyle
run: |
make check-codestyle
- name: Type check
run: |
make type
# Do not run for python 3.8 (mypy internal error)
if: matrix.python-version != '3.8'
- name: Test with pytest
run: |
make pytest
4 changes: 2 additions & 2 deletions .readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ conda:
environment: docs/conda_env.yml

build:
os: ubuntu-22.04
os: ubuntu-24.04
tools:
python: "mambaforge-22.9"
python: "mambaforge-23.11"
2 changes: 1 addition & 1 deletion CODE_OF_CONDUCT.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
identity and expression, level of experience, education, socioeconomic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.

Expand Down
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ into two categories:
- Create an issue about your intended feature, and we shall discuss the design and
implementation. Once we agree that the plan looks good, go ahead and implement it.
2. You want to implement a feature or bug-fix for an outstanding issue
- Look at the outstanding issues here: https://github.com/DLR-RM/stable-baselines3/issues
- Look at the outstanding issues here: https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted
- Pick an issue or feature and comment on the task that you want to work on this feature.
- If you need more context on a particular issue, please ask, and we shall provide.

Expand Down
36 changes: 23 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
<img src="docs/\_static/img/logo.png" align="right" width="40%"/>

<!-- [![pipeline status](https://gitlab.com/araffin/stable-baselines3/badges/master/pipeline.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master) -->
![CI](https://github.com/DLR-RM/stable-baselines3/workflows/CI/badge.svg)
[![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/master/coverage.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master)
[![CI](https://github.com/DLR-RM/stable-baselines3/workflows/CI/badge.svg)](https://github.com/DLR-RM/stable-baselines3/actions/workflows/ci.yml)
[![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/master/coverage.svg)](https://github.com/DLR-RM/stable-baselines3/actions/workflows/ci.yml)
[![codestyle](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)


# Stable Baselines3

<img src="docs/\_static/img/logo.png" align="right" width="40%"/>

Stable Baselines3 (SB3) is a set of reliable implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines).

You can read a detailed presentation of Stable Baselines3 in the [v1.0 blog post](https://araffin.github.io/post/sb3/) or our [JMLR paper](https://jmlr.org/papers/volume22/20-1364/20-1364.pdf).
Expand All @@ -22,6 +22,8 @@ These algorithms will make it easier for the research community and industry to
**The performance of each algorithm was tested** (see *Results* section in their respective page),
you can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselines3/issues/48) and [#49](https://github.com/DLR-RM/stable-baselines3/issues/49) for more details.

We also provide detailed logs and reports on the [OpenRL Benchmark](https://wandb.ai/openrlbenchmark/sb3) platform.


| **Features** | **Stable-Baselines3** |
| --------------------------- | ----------------------|
Expand All @@ -41,7 +43,13 @@ you can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselin

### Planned features

Please take a look at the [Roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) and [Milestones](https://github.com/DLR-RM/stable-baselines3/milestones).
Since most of the features from the [original roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) have been implemented, there are no major changes planned for SB3, it is now *stable*.
If you want to contribute, you can search in the issues for the ones where [help is welcomed](https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted) and the other [proposed enhancements](https://github.com/DLR-RM/stable-baselines3/labels/enhancement).

While SB3 development is now focused on bug fixes and maintenance (doc update, user experience, ...), there is more active development going on in the associated repositories:
- newer algorithms are regularly added to the [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) repository
- faster variants are developed in the [SBX (SB3 + Jax)](https://github.com/araffin/sbx) repository
- the training framework for SB3, the RL Zoo, has an active [roadmap](https://github.com/DLR-RM/rl-baselines3-zoo/issues/299)

## Migration guide: from Stable-Baselines (SB2) to Stable-Baselines3 (SB3)

Expand Down Expand Up @@ -79,7 +87,7 @@ Documentation: https://rl-baselines3-zoo.readthedocs.io/en/master/

We implement experimental features in a separate contrib repository: [SB3-Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib)

This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Recurrent PPO (PPO LSTM), Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO).
This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Recurrent PPO (PPO LSTM), CrossQ, Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO).

Documentation is available online: [https://sb3-contrib.readthedocs.io/](https://sb3-contrib.readthedocs.io/)

Expand All @@ -97,17 +105,16 @@ It provides a minimal number of features compared to SB3 but can be much faster
### Prerequisites
Stable Baselines3 requires Python 3.8+.

#### Windows 10
#### Windows

To install stable-baselines on Windows, please look at the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/install.html#prerequisites).


### Install using pip
Install the Stable Baselines3 package:
```sh
pip install 'stable-baselines3[extra]'
```
pip install stable-baselines3[extra]
```
**Note:** Some shells such as Zsh require quotation marks around brackets, i.e. `pip install 'stable-baselines3[extra]'` ([More Info](https://stackoverflow.com/a/30539963)).

This includes an optional dependencies like Tensorboard, OpenCV or `ale-py` to train on atari games. If you do not need those, you can use:
```sh
Expand Down Expand Up @@ -177,6 +184,7 @@ All the following examples can be executed online using Google Colab notebooks:
| ------------------- | ------------------ | ------------------ | ------------------ | ------------------- | ------------------ | --------------------------------- |
| ARS<sup>[1](#f1)</sup> | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| CrossQ<sup>[1](#f1)</sup> | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| DDPG | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| DQN | :x: | :x: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| HER | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
Expand All @@ -191,7 +199,7 @@ All the following examples can be executed online using Google Colab notebooks:

<b id="f1">1</b>: Implemented in [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) GitHub repository.

Actions `gym.spaces`:
Actions `gymnasium.spaces`:
* `Box`: A N-dimensional box that contains every point in the action space.
* `Discrete`: A list of possible actions, where each timestep only one of the actions can be used.
* `MultiDiscrete`: A list of possible actions, where each timestep only one action of each discrete set can be used.
Expand All @@ -218,9 +226,9 @@ To run a single test:
python3 -m pytest -v -k 'test_check_env_dict_action'
```

You can also do a static type check using `pytype` and `mypy`:
You can also do a static type check using `mypy`:
```sh
pip install pytype mypy
pip install mypy
make type
```

Expand Down Expand Up @@ -252,6 +260,8 @@ To cite this repository in publications:
}
```

Note: If you need to refer to a specific version of SB3, you can also use the [Zenodo DOI](https://doi.org/10.5281/zenodo.8123988).

## Maintainers

Stable-Baselines3 is currently maintained by [Ashley Hill](https://github.com/hill-a) (aka @hill-a), [Antonin Raffin](https://araffin.github.io/) (aka [@araffin](https://github.com/araffin)), [Maximilian Ernestus](https://github.com/ernestum) (aka @ernestum), [Adam Gleave](https://github.com/adamgleave) (@AdamGleave), [Anssi Kanervisto](https://github.com/Miffyli) (@Miffyli) and [Quentin Gallouédec](https://gallouedec.com/) (@qgallouedec).
Expand Down
12 changes: 6 additions & 6 deletions docs/conda_env.yml
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
name: root
channels:
- pytorch
- defaults
- conda-forge
dependencies:
- cpuonly=1.0=0
- pip=22.3.1
- python=3.8
- pytorch=1.13.0=py3.8_cpu_0
- pip=24.2
- python=3.11
- pytorch=2.5.0=py3.11_cpu_0
- pip:
- gymnasium
- gymnasium>=0.29.1,<1.1.0
- cloudpickle
- opencv-python-headless
- pandas
- numpy
- numpy>=1.20,<2.0
- matplotlib
- sphinx>=5,<8
- sphinx_rtd_theme>=1.3.0
Expand Down
1 change: 1 addition & 0 deletions docs/guide/algos.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary``
=================== =========== ============ ================= =============== ================
ARS [#f1]_ ✔️ ✔️ ❌ ❌ ✔️
A2C ✔️ ✔️ ✔️ ✔️ ✔️
CrossQ [#f1]_ ✔️ ❌ ❌ ❌ ✔️
DDPG ✔️ ❌ ❌ ❌ ✔️
DQN ❌ ✔️ ❌ ❌ ✔️
HER ✔️ ✔️ ❌ ❌ ✔️
Expand Down
1 change: 1 addition & 0 deletions docs/guide/sb3_contrib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ See documentation for the full list of included features.
- `PPO with recurrent policy (RecurrentPPO aka PPO LSTM) <https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/>`_
- `Truncated Quantile Critics (TQC)`_
- `Trust Region Policy Optimization (TRPO) <https://arxiv.org/abs/1502.05477>`_
- `Batch Normalization in Deep Reinforcement Learning (CrossQ) <https://openreview.net/forum?id=PczQtTsTIX>`_


**Gym Wrappers**:
Expand Down
4 changes: 3 additions & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,14 @@ To cite this project in publications:
url = {http://jmlr.org/papers/v22/20-1364.html}
}
Note: If you need to refer to a specific version of SB3, you can also use the `Zenodo DOI <https://doi.org/10.5281/zenodo.8123988>`_.

Contributing
------------

To any interested in making the rl baselines better, there are still some improvements
that need to be done.
You can check issues in the `repo <https://github.com/DLR-RM/stable-baselines3/issues>`_.
You can check issues in the `repository <https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted>`_.

If you want to contribute, please read `CONTRIBUTING.md <https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md>`_ first.

Expand Down
37 changes: 36 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,35 @@
Changelog
==========

Release 2.4.0a5 (WIP)
Release 2.4.0a11 (WIP)
--------------------------

**New algorithm: CrossQ in SB3 Contrib, Gymnasium v1.0 support**

.. note::

DQN (and QR-DQN) models saved with SB3 < 2.4.0 will show a warning about
truncation of optimizer state when loaded with SB3 >= 2.4.0.
To suppress the warning, simply save the model again.
You can find more info in `PR #1963 <https://github.com/DLR-RM/stable-baselines3/pull/1963>`_

.. warning::

Stable-Baselines3 (SB3) v2.4.0 will be the last one supporting Python 3.8 (end of life in October 2024)
and PyTorch < 2.0.
We highly recommended you to upgrade to Python >= 3.9 and PyTorch >= 2.0.


Breaking Changes:
^^^^^^^^^^^^^^^^^
- Increase minimum required version of Gymnasium to 0.29.1

New Features:
^^^^^^^^^^^^^
- Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ)
- Enabled np.ndarray logging for TensorBoardOutputFormat as histogram (see GH#1634) (@iwishwasaneagle)
- Updated env checker to warn users when using multi-dim array to define `MultiDiscrete` spaces
- Added support for Gymnasium v1.0

Bug Fixes:
^^^^^^^^^^
Expand All @@ -21,15 +42,23 @@ Bug Fixes:
- Fixed error when loading a model that has ``net_arch`` manually set to ``None`` (@jak3122)
- Set requirement numpy<2.0 until PyTorch is compatible (https://github.com/pytorch/pytorch/issues/107302)
- Updated DQN optimizer input to only include q_network parameters, removing the target_q_network ones (@corentinlger)
- Fixed ``test_buffers.py::test_device`` which was not actually checking the device of tensors (@rhaps0dy)


`SB3-Contrib`_
^^^^^^^^^^^^^^
- Added ``CrossQ`` algorithm, from "Batch Normalization in Deep Reinforcement Learning" paper (@danielpalen)
- Added ``BatchRenorm`` PyTorch layer used in ``CrossQ`` (@danielpalen)
- Updated QR-DQN optimizer input to only include quantile_net parameters (@corentinlger)
- Fixed loading QRDQN changes `target_update_interval` (@jak3122)

`RL Zoo`_
^^^^^^^^^
- Updated defaults hyperparameters for TQC/SAC for Swimmer-v4 (decrease gamma for more consistent results)

`SBX`_ (SB3 + Jax)
^^^^^^^^^^^^^^^^^^
- Added CNN support for DQN

Deprecations:
^^^^^^^^^^^^^
Expand All @@ -39,12 +68,18 @@ Others:
- Fixed various typos (@cschindlbeck)
- Remove unnecessary SDE noise resampling in PPO update (@brn-dev)
- Updated PyTorch version on CI to 2.3.1
- Added a warning to recommend using CPU with on policy algorithms (A2C/PPO) and ``MlpPolicy``
- Switched to uv to download packages faster on GitHub CI
- Updated dependencies for read the doc
- Removed unnecessary ``copy_obs_dict`` method for ``SubprocVecEnv``, remove the use of ordered dict and rename ``flatten_obs`` to ``stack_obs``

Bug Fixes:
^^^^^^^^^^

Documentation:
^^^^^^^^^^^^^^
- Updated PPO doc to recommend using CPU with ``MlpPolicy``
- Clarified documentation about planned features and citing software

Release 2.3.2 (2024-04-27)
--------------------------
Expand Down
17 changes: 17 additions & 0 deletions docs/modules/ppo.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,23 @@ Train a PPO agent on ``CartPole-v1`` using 4 environments.
vec_env.render("human")
.. note::

PPO is meant to be run primarily on the CPU, especially when you are not using a CNN. To improve CPU utilization, try turning off the GPU and using ``SubprocVecEnv`` instead of the default ``DummyVecEnv``:

.. code-block::
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv
if __name__=="__main__":
env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv)
model = PPO("MlpPolicy", env, device="cpu")
model.learn(total_timesteps=25_000)
For more information, see :ref:`Vectorized Environments <vec_env>`, `Issue #1245 <https://github.com/DLR-RM/stable-baselines3/issues/1245#issuecomment-1435766949>`_ or the `Multiprocessing notebook <https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb>`_.

Results
-------

Expand Down
Loading

0 comments on commit 148e4aa

Please sign in to comment.