Skip to content

Commit

Permalink
Merge branch 'master' into feat/mps-support
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin authored Oct 29, 2024
2 parents 955382e + dd3d0ac commit 263e657
Show file tree
Hide file tree
Showing 15 changed files with 110 additions and 32 deletions.
11 changes: 7 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,21 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# Use uv for faster downloads
pip install uv
# cpu version of pytorch
pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu
# See https://github.com/astral-sh/uv/issues/1497
uv pip install --system torch==2.3.1+cpu --index https://download.pytorch.org/whl/cpu
# Install Atari Roms
pip install autorom
uv pip install --system autorom
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz
pip install .[extra_no_roms,tests,docs]
uv pip install --system .[extra_no_roms,tests,docs]
# Use headless version
pip install opencv-python-headless
uv pip install --system opencv-python-headless
- name: Lint with ruff
run: |
make lint
Expand Down
4 changes: 2 additions & 2 deletions .readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ conda:
environment: docs/conda_env.yml

build:
os: ubuntu-22.04
os: ubuntu-24.04
tools:
python: "mambaforge-22.9"
python: "mambaforge-23.11"
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ into two categories:
- Create an issue about your intended feature, and we shall discuss the design and
implementation. Once we agree that the plan looks good, go ahead and implement it.
2. You want to implement a feature or bug-fix for an outstanding issue
- Look at the outstanding issues here: https://github.com/DLR-RM/stable-baselines3/issues
- Look at the outstanding issues here: https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted
- Pick an issue or feature and comment on the task that you want to work on this feature.
- If you need more context on a particular issue, please ask, and we shall provide.

Expand Down
32 changes: 21 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<!-- [![pipeline status](https://gitlab.com/araffin/stable-baselines3/badges/master/pipeline.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master) -->
![CI](https://github.com/DLR-RM/stable-baselines3/workflows/CI/badge.svg)
[![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/master/coverage.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master)
[![CI](https://github.com/DLR-RM/stable-baselines3/workflows/CI/badge.svg)](https://github.com/DLR-RM/stable-baselines3/actions/workflows/ci.yml)
[![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/master/coverage.svg)](https://github.com/DLR-RM/stable-baselines3/actions/workflows/ci.yml)
[![codestyle](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)


Expand All @@ -22,6 +22,8 @@ These algorithms will make it easier for the research community and industry to
**The performance of each algorithm was tested** (see *Results* section in their respective page),
you can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselines3/issues/48) and [#49](https://github.com/DLR-RM/stable-baselines3/issues/49) for more details.

We also provide detailed logs and reports on the [OpenRL Benchmark](https://wandb.ai/openrlbenchmark/sb3) platform.


| **Features** | **Stable-Baselines3** |
| --------------------------- | ----------------------|
Expand All @@ -41,7 +43,13 @@ you can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselin

### Planned features

Please take a look at the [Roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) and [Milestones](https://github.com/DLR-RM/stable-baselines3/milestones).
Since most of the features from the [original roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) have been implemented, there are no major changes planned for SB3, it is now *stable*.
If you want to contribute, you can search in the issues for the ones where [help is welcomed](https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted) and the other [proposed enhancements](https://github.com/DLR-RM/stable-baselines3/labels/enhancement).

While SB3 development is now focused on bug fixes and maintenance (doc update, user experience, ...), there is more active development going on in the associated repositories:
- newer algorithms are regularly added to the [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) repository
- faster variants are developed in the [SBX (SB3 + Jax)](https://github.com/araffin/sbx) repository
- the training framework for SB3, the RL Zoo, has an active [roadmap](https://github.com/DLR-RM/rl-baselines3-zoo/issues/299)

## Migration guide: from Stable-Baselines (SB2) to Stable-Baselines3 (SB3)

Expand Down Expand Up @@ -79,7 +87,7 @@ Documentation: https://rl-baselines3-zoo.readthedocs.io/en/master/

We implement experimental features in a separate contrib repository: [SB3-Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib)

This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Recurrent PPO (PPO LSTM), Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO).
This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Recurrent PPO (PPO LSTM), CrossQ, Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO).

Documentation is available online: [https://sb3-contrib.readthedocs.io/](https://sb3-contrib.readthedocs.io/)

Expand All @@ -97,17 +105,16 @@ It provides a minimal number of features compared to SB3 but can be much faster
### Prerequisites
Stable Baselines3 requires Python 3.8+.

#### Windows 10
#### Windows

To install stable-baselines on Windows, please look at the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/install.html#prerequisites).


### Install using pip
Install the Stable Baselines3 package:
```sh
pip install 'stable-baselines3[extra]'
```
pip install stable-baselines3[extra]
```
**Note:** Some shells such as Zsh require quotation marks around brackets, i.e. `pip install 'stable-baselines3[extra]'` ([More Info](https://stackoverflow.com/a/30539963)).

This includes an optional dependencies like Tensorboard, OpenCV or `ale-py` to train on atari games. If you do not need those, you can use:
```sh
Expand Down Expand Up @@ -177,6 +184,7 @@ All the following examples can be executed online using Google Colab notebooks:
| ------------------- | ------------------ | ------------------ | ------------------ | ------------------- | ------------------ | --------------------------------- |
| ARS<sup>[1](#f1)</sup> | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| CrossQ<sup>[1](#f1)</sup> | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| DDPG | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| DQN | :x: | :x: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| HER | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
Expand All @@ -191,7 +199,7 @@ All the following examples can be executed online using Google Colab notebooks:

<b id="f1">1</b>: Implemented in [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) GitHub repository.

Actions `gym.spaces`:
Actions `gymnasium.spaces`:
* `Box`: A N-dimensional box that contains every point in the action space.
* `Discrete`: A list of possible actions, where each timestep only one of the actions can be used.
* `MultiDiscrete`: A list of possible actions, where each timestep only one action of each discrete set can be used.
Expand All @@ -218,9 +226,9 @@ To run a single test:
python3 -m pytest -v -k 'test_check_env_dict_action'
```

You can also do a static type check using `pytype` and `mypy`:
You can also do a static type check using `mypy`:
```sh
pip install pytype mypy
pip install mypy
make type
```

Expand Down Expand Up @@ -252,6 +260,8 @@ To cite this repository in publications:
}
```

Note: If you need to refer to a specific version of SB3, you can also use the [Zenodo DOI](https://doi.org/10.5281/zenodo.8123988).

## Maintainers

Stable-Baselines3 is currently maintained by [Ashley Hill](https://github.com/hill-a) (aka @hill-a), [Antonin Raffin](https://araffin.github.io/) (aka [@araffin](https://github.com/araffin)), [Maximilian Ernestus](https://github.com/ernestum) (aka @ernestum), [Adam Gleave](https://github.com/adamgleave) (@AdamGleave), [Anssi Kanervisto](https://github.com/Miffyli) (@Miffyli) and [Quentin Gallouédec](https://gallouedec.com/) (@qgallouedec).
Expand Down
12 changes: 6 additions & 6 deletions docs/conda_env.yml
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
name: root
channels:
- pytorch
- defaults
- conda-forge
dependencies:
- cpuonly=1.0=0
- pip=22.3.1
- python=3.8
- pytorch=1.13.0=py3.8_cpu_0
- pip=24.2
- python=3.11
- pytorch=2.5.0=py3.11_cpu_0
- pip:
- gymnasium
- gymnasium>=0.28.1,<0.30
- cloudpickle
- opencv-python-headless
- pandas
- numpy
- numpy>=1.20,<2.0
- matplotlib
- sphinx>=5,<8
- sphinx_rtd_theme>=1.3.0
Expand Down
1 change: 1 addition & 0 deletions docs/guide/algos.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary``
=================== =========== ============ ================= =============== ================
ARS [#f1]_ ✔️ ✔️ ❌ ❌ ✔️
A2C ✔️ ✔️ ✔️ ✔️ ✔️
CrossQ [#f1]_ ✔️ ❌ ❌ ❌ ✔️
DDPG ✔️ ❌ ❌ ❌ ✔️
DQN ❌ ✔️ ❌ ❌ ✔️
HER ✔️ ✔️ ❌ ❌ ✔️
Expand Down
1 change: 1 addition & 0 deletions docs/guide/sb3_contrib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ See documentation for the full list of included features.
- `PPO with recurrent policy (RecurrentPPO aka PPO LSTM) <https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/>`_
- `Truncated Quantile Critics (TQC)`_
- `Trust Region Policy Optimization (TRPO) <https://arxiv.org/abs/1502.05477>`_
- `Batch Normalization in Deep Reinforcement Learning (CrossQ) <https://openreview.net/forum?id=PczQtTsTIX>`_


**Gym Wrappers**:
Expand Down
4 changes: 3 additions & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,14 @@ To cite this project in publications:
url = {http://jmlr.org/papers/v22/20-1364.html}
}
Note: If you need to refer to a specific version of SB3, you can also use the `Zenodo DOI <https://doi.org/10.5281/zenodo.8123988>`_.

Contributing
------------

To any interested in making the rl baselines better, there are still some improvements
that need to be done.
You can check issues in the `repo <https://github.com/DLR-RM/stable-baselines3/issues>`_.
You can check issues in the `repository <https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted>`_.

If you want to contribute, please read `CONTRIBUTING.md <https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md>`_ first.

Expand Down
13 changes: 12 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
Changelog
==========

Release 2.4.0a9 (WIP)
Release 2.4.0a10 (WIP)
--------------------------

**New algorithm: CrossQ in SB3 Contrib**

.. note::

DQN (and QR-DQN) models saved with SB3 < 2.4.0 will show a warning about
Expand Down Expand Up @@ -43,6 +45,10 @@ Bug Fixes:

`SB3-Contrib`_
^^^^^^^^^^^^^^
- Added ``CrossQ`` algorithm, from "Batch Normalization in Deep Reinforcement Learning" paper (@danielpalen)
- Added ``BatchRenorm`` PyTorch layer used in ``CrossQ`` (@danielpalen)
- Updated QR-DQN optimizer input to only include quantile_net parameters (@corentinlger)
- Fixed loading QRDQN changes `target_update_interval` (@jak3122)

`RL Zoo`_
^^^^^^^^^
Expand All @@ -60,12 +66,17 @@ Others:
- Fixed various typos (@cschindlbeck)
- Remove unnecessary SDE noise resampling in PPO update (@brn-dev)
- Updated PyTorch version on CI to 2.3.1
- Added a warning to recommend using CPU with on policy algorithms (A2C/PPO) and ``MlpPolicy``
- Switched to uv to download packages faster on GitHub CI
- Updated dependencies for read the doc

Bug Fixes:
^^^^^^^^^^

Documentation:
^^^^^^^^^^^^^^
- Updated PPO doc to recommend using CPU with ``MlpPolicy``
- Clarified documentation about planned features and citing software

Release 2.3.2 (2024-04-27)
--------------------------
Expand Down
17 changes: 17 additions & 0 deletions docs/modules/ppo.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,23 @@ Train a PPO agent on ``CartPole-v1`` using 4 environments.
vec_env.render("human")
.. note::

PPO is meant to be run primarily on the CPU, especially when you are not using a CNN. To improve CPU utilization, try turning off the GPU and using ``SubprocVecEnv`` instead of the default ``DummyVecEnv``:

.. code-block::
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv
if __name__=="__main__":
env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv)
model = PPO("MlpPolicy", env, device="cpu")
model.learn(total_timesteps=25_000)
For more information, see :ref:`Vectorized Environments <vec_env>`, `Issue #1245 <https://github.com/DLR-RM/stable-baselines3/issues/1245#issuecomment-1435766949>`_ or the `Multiprocessing notebook <https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb>`_.

Results
-------

Expand Down
23 changes: 23 additions & 0 deletions stable_baselines3/common/on_policy_algorithm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
import time
import warnings
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

import numpy as np
Expand Down Expand Up @@ -135,6 +136,28 @@ def _setup_model(self) -> None:
self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, **self.policy_kwargs
)
self.policy = self.policy.to(self.device)
# Warn when not using CPU with MlpPolicy
self._maybe_recommend_cpu()

def _maybe_recommend_cpu(self, mlp_class_name: str = "ActorCriticPolicy") -> None:
"""
Recommend to use CPU only when using A2C/PPO with MlpPolicy.
:param: The name of the class for the default MlpPolicy.
"""
policy_class_name = self.policy_class.__name__
if self.device != th.device("cpu") and policy_class_name == mlp_class_name:
warnings.warn(
f"You are trying to run {self.__class__.__name__} on the GPU, "
"but it is primarily intended to run on the CPU when not using a CNN policy "
f"(you are using {policy_class_name} which should be a MlpPolicy). "
"See https://github.com/DLR-RM/stable-baselines3/issues/1245 "
"for more info. "
"You can pass `device='cpu'` or `export CUDA_VISIBLE_DEVICES=` to force using the CPU."
"Note: The model will train, but the GPU utilization will be poor and "
"the training might take longer than on CPU.",
UserWarning,
)

def collect_rollouts(
self,
Expand Down
4 changes: 2 additions & 2 deletions stable_baselines3/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def set_random_seed(seed: int, using_cuda: bool = False) -> None:


# From stable baselines
def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray:
def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> float:
"""
Computes fraction of variance that ypred explains about y.
Returns 1 - Var[y-ypred] / Var[y]
Expand All @@ -62,7 +62,7 @@ def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray:
"""
assert y_true.ndim == 1 and y_pred.ndim == 1
var_y = np.var(y_true)
return np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
return np.nan if var_y == 0 else float(1 - np.var(y_true - y_pred) / var_y)


def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) -> None:
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.4.0a9
2.4.0a10
14 changes: 12 additions & 2 deletions tests/test_run.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import gymnasium as gym
import numpy as np
import pytest
import torch as th

from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
from stable_baselines3.common.env_util import make_vec_env
Expand Down Expand Up @@ -211,8 +212,11 @@ def test_warn_dqn_multi_env():


def test_ppo_warnings():
"""Test that PPO warns and errors correctly on
problematic rollout buffer sizes"""
"""
Test that PPO warns and errors correctly on
problematic rollout buffer sizes,
and recommend using CPU.
"""

# Only 1 step: advantage normalization will return NaN
with pytest.raises(AssertionError):
Expand All @@ -234,3 +238,9 @@ def test_ppo_warnings():
loss = model.logger.name_to_value["train/loss"]
assert loss > 0
assert not np.isnan(loss) # check not nan (since nan does not equal nan)

with pytest.warns(UserWarning, match="You are trying to run PPO on the GPU"):
model = PPO("MlpPolicy", "Pendulum-v1")
# Pretend to be on the GPU
model.device = th.device("cuda")
model._maybe_recommend_cpu()
2 changes: 1 addition & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def test_custom_vec_env(tmp_path):


@pytest.mark.parametrize("direct_policy", [False, True])
def test_evaluate_policy(direct_policy: bool):
def test_evaluate_policy(direct_policy):
model = A2C("MlpPolicy", "Pendulum-v1", seed=0)
n_steps_per_episode, n_eval_episodes = 200, 2

Expand Down

0 comments on commit 263e657

Please sign in to comment.