Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
iwishiwasaneagle authored Jul 30, 2024
2 parents 5121148 + 6ad6fa5 commit 880bd2d
Show file tree
Hide file tree
Showing 55 changed files with 870 additions and 241 deletions.
72 changes: 37 additions & 35 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ name: CI

on:
push:
branches: [ master ]
branches: [master]
pull_request:
branches: [ master ]
branches: [master]

jobs:
build:
Expand All @@ -23,38 +23,40 @@ jobs:
python-version: ["3.8", "3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu
# Install Atari Roms
pip install autorom
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz
# Install Atari Roms
pip install autorom
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz
pip install .[extra_no_roms,tests,docs]
# Use headless version
pip install opencv-python-headless
- name: Lint with ruff
run: |
make lint
- name: Build the doc
run: |
make doc
- name: Check codestyle
run: |
make check-codestyle
- name: Type check
run: |
make type
- name: Test with pytest
run: |
make pytest
pip install .[extra_no_roms,tests,docs]
# Use headless version
pip install opencv-python-headless
- name: Lint with ruff
run: |
make lint
- name: Build the doc
run: |
make doc
- name: Check codestyle
run: |
make check-codestyle
- name: Type check
run: |
make type
# Do not run for python 3.8 (mypy internal error)
if: matrix.python-version != '3.8'
- name: Test with pytest
run: |
make pytest
2 changes: 1 addition & 1 deletion CODE_OF_CONDUCT.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
identity and expression, level of experience, education, socioeconomic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.

Expand Down
8 changes: 4 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,19 @@ type: mypy
lint:
# stop the build if there are Python syntax errors or undefined names
# see https://www.flake8rules.com/
ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --show-source
ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
# exit-zero treats all errors as warnings.
ruff ${LINT_PATHS} --exit-zero
ruff check ${LINT_PATHS} --exit-zero --output-format=concise

format:
# Sort imports
ruff --select I ${LINT_PATHS} --fix
ruff check --select I ${LINT_PATHS} --fix
# Reformat using black
black ${LINT_PATHS}

check-codestyle:
# Sort imports
ruff --select I ${LINT_PATHS}
ruff check --select I ${LINT_PATHS}
# Reformat using black
black --check ${LINT_PATHS}

Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ Documentation is available online: [https://sb3-contrib.readthedocs.io/](https:/

## Stable-Baselines Jax (SBX)

[Stable Baselines Jax (SBX)](https://github.com/araffin/sbx) is a proof of concept version of Stable-Baselines3 in Jax.
[Stable Baselines Jax (SBX)](https://github.com/araffin/sbx) is a proof of concept version of Stable-Baselines3 in Jax, with recent algorithms like DroQ or CrossQ.

It provides a minimal number of features compared to SB3 but can be much faster (up to 20x times!): https://twitter.com/araffin2/status/1590714558628253698

Expand Down Expand Up @@ -127,7 +127,7 @@ import gymnasium as gym

from stable_baselines3 import PPO

env = gym.make("CartPole-v1")
env = gym.make("CartPole-v1", render_mode="human")

model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)
Expand Down Expand Up @@ -192,7 +192,7 @@ All the following examples can be executed online using Google Colab notebooks:
<b id="f1">1</b>: Implemented in [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) GitHub repository.

Actions `gym.spaces`:
* `Box`: A N-dimensional box that containes every point in the action space.
* `Box`: A N-dimensional box that contains every point in the action space.
* `Discrete`: A list of possible actions, where each timestep only one of the actions can be used.
* `MultiDiscrete`: A list of possible actions, where each timestep only one action of each discrete set can be used.
* `MultiBinary`: A list of possible actions, where each timestep any of the actions can be used in any combination.
Expand Down
3 changes: 2 additions & 1 deletion docs/guide/algos.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ Actions ``gym.spaces``:

.. note::

More algorithms (like QR-DQN or TQC) are implemented in our :ref:`contrib repo <sb3_contrib>`.
More algorithms (like QR-DQN or TQC) are implemented in our :ref:`contrib repo <sb3_contrib>`
and in our :ref:`SBX (SB3 + Jax) repo <sbx>` (DroQ, CrossQ, ...).

.. note::

Expand Down
15 changes: 8 additions & 7 deletions docs/guide/callbacks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,25 @@ You can find two examples of custom callbacks in the documentation: one for savi
:param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
"""
def __init__(self, verbose=0):
def __init__(self, verbose: int = 0):
super().__init__(verbose)
# Those variables will be accessible in the callback
# (they are defined in the base class)
# The RL model
# self.model = None # type: BaseAlgorithm
# An alias for self.model.get_env(), the environment used for training
# self.training_env = None # type: Union[gym.Env, VecEnv, None]
# self.training_env # type: VecEnv
# Number of time the callback was called
# self.n_calls = 0 # type: int
# num_timesteps = n_envs * n times env.step() was called
# self.num_timesteps = 0 # type: int
# local and global variables
# self.locals = None # type: Dict[str, Any]
# self.globals = None # type: Dict[str, Any]
# self.locals = {} # type: Dict[str, Any]
# self.globals = {} # type: Dict[str, Any]
# The logger object, used to report things in the terminal
# self.logger = None # stable_baselines3.common.logger
# # Sometimes, for event callback, it is useful
# # to have access to the parent object
# self.logger # type: stable_baselines3.common.logger.Logger
# Sometimes, for event callback, it is useful
# to have access to the parent object
# self.parent = None # type: Optional[BaseCallback]
def _on_training_start(self) -> None:
Expand Down
10 changes: 5 additions & 5 deletions docs/guide/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ Multiprocessing: Unleashing the Power of Vectorized Environments
:param env_id: the environment ID
:param num_env: the number of environments you wish to have in subprocesses
:param seed: the inital seed for RNG
:param seed: the initial seed for RNG
:param rank: index of the subprocess
"""
def _init():
Expand Down Expand Up @@ -179,9 +179,9 @@ Multiprocessing with off-policy algorithms
vec_env = make_vec_env("Pendulum-v0", n_envs=4, seed=0)
# We collect 4 transitions per call to `ènv.step()`
# and performs 2 gradient steps per call to `ènv.step()`
# if gradient_steps=-1, then we would do 4 gradients steps per call to `ènv.step()`
# We collect 4 transitions per call to `env.step()`
# and performs 2 gradient steps per call to `env.step()`
# if gradient_steps=-1, then we would do 4 gradients steps per call to `env.step()`
model = SAC("MlpPolicy", vec_env, train_freq=1, gradient_steps=2, verbose=1)
model.learn(total_timesteps=10_000)
Expand Down Expand Up @@ -436,7 +436,7 @@ will compute a running average and standard deviation of input features (it can
log_dir = "/tmp/"
model.save(log_dir + "ppo_halfcheetah")
stats_path = os.path.join(log_dir, "vec_normalize.pkl")
env.save(stats_path)
vec_env.save(stats_path)
# To demonstrate loading
del model, vec_env
Expand Down
85 changes: 48 additions & 37 deletions docs/guide/export.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,53 +31,52 @@ to do inference in another framework.
Export to ONNX
-----------------

As of June 2021, ONNX format `doesn't support <https://github.com/onnx/onnx/issues/3033>`_ exporting models that use the ``broadcast_tensors`` functionality of pytorch. So in order to export the trained stable-baseline3 models in the ONNX format, we need to first remove the layers that use broadcasting. This can be done by creating a class that removes the unsupported layers.

The following examples are for ``MlpPolicy`` only, and are general examples. Note that you have to preprocess the observation the same way stable-baselines3 agent does (see ``common.preprocessing.preprocess_obs``).
If you are using PyTorch 2.0+ and ONNX Opset 14+, you can easily export SB3 policies using the following code:

For PPO, assuming a shared feature extractor.

.. warning::

The following example is for continuous actions only.
When using discrete or binary actions, you must do some `post-processing <https://github.com/DLR-RM/stable-baselines3/blob/f3a35aa786ee41ffff599b99fa1607c067e89074/stable_baselines3/common/policies.py#L621-L637>`_
to obtain the action (e.g., convert action logits to action).
The following returns normalized actions and doesn't include the `post-processing <https://github.com/DLR-RM/stable-baselines3/blob/a9273f968eaf8c6e04302a07d803eebfca6e7e86/stable_baselines3/common/policies.py#L370-L377>`_ step that is done with continuous actions
(clip or unscale the action to the correct space).


.. code-block:: python
import torch as th
from typing import Tuple
from stable_baselines3 import PPO
from stable_baselines3.common.policies import BasePolicy
class OnnxablePolicy(th.nn.Module):
def __init__(self, extractor, action_net, value_net):
class OnnxableSB3Policy(th.nn.Module):
def __init__(self, policy: BasePolicy):
super().__init__()
self.extractor = extractor
self.action_net = action_net
self.value_net = value_net
self.policy = policy
def forward(self, observation):
# NOTE: You may have to process (normalize) observation in the correct
# way before using this. See `common.preprocessing.preprocess_obs`
action_hidden, value_hidden = self.extractor(observation)
return self.action_net(action_hidden), self.value_net(value_hidden)
def forward(self, observation: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
# NOTE: Preprocessing is included, but postprocessing
# (clipping/inscaling actions) is not,
# If needed, you also need to transpose the images so that they are channel first
# use deterministic=False if you want to export the stochastic policy
# policy() returns `actions, values, log_prob` for PPO
return self.policy(observation, deterministic=True)
# Example: model = PPO("MlpPolicy", "Pendulum-v1")
PPO("MlpPolicy", "Pendulum-v1").save("PathToTrainedModel")
model = PPO.load("PathToTrainedModel.zip", device="cpu")
onnxable_model = OnnxablePolicy(
model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net
)
onnx_policy = OnnxableSB3Policy(model.policy)
observation_size = model.observation_space.shape
dummy_input = th.randn(1, *observation_size)
th.onnx.export(
onnxable_model,
onnx_policy,
dummy_input,
"my_ppo_model.onnx",
opset_version=9,
opset_version=17,
input_names=["input"],
)
Expand All @@ -93,7 +92,13 @@ For PPO, assuming a shared feature extractor.
observation = np.zeros((1, *observation_size)).astype(np.float32)
ort_sess = ort.InferenceSession(onnx_path)
action, value = ort_sess.run(None, {"input": observation})
actions, values, log_prob = ort_sess.run(None, {"input": observation})
print(actions, values, log_prob)
# Check that the predictions are the same
with th.no_grad():
print(model.policy(th.as_tensor(observation), deterministic=True))
For SAC the procedure is similar. The example shown only exports the actor network as the actor is sufficient to roll out the trained policies.
Expand All @@ -108,23 +113,16 @@ For SAC the procedure is similar. The example shown only exports the actor netwo
class OnnxablePolicy(th.nn.Module):
def __init__(self, actor: th.nn.Module):
super().__init__()
# Removing the flatten layer because it can't be onnxed
self.actor = th.nn.Sequential(
actor.latent_pi,
actor.mu,
# For gSDE
# th.nn.Hardtanh(min_val=-actor.clip_mean, max_val=actor.clip_mean),
# Squash the output
th.nn.Tanh(),
)
self.actor = actor
def forward(self, observation: th.Tensor) -> th.Tensor:
# NOTE: You may have to process (normalize) observation in the correct
# way before using this. See `common.preprocessing.preprocess_obs`
return self.actor(observation)
# NOTE: You may have to postprocess (unnormalize) actions
# to the correct bounds (see commented code below)
return self.actor(observation, deterministic=True)
# Example: model = SAC("MlpPolicy", "Pendulum-v1")
SAC("MlpPolicy", "Pendulum-v1").save("PathToTrainedModel.zip")
model = SAC.load("PathToTrainedModel.zip", device="cpu")
onnxable_model = OnnxablePolicy(model.policy.actor)
Expand All @@ -134,7 +132,7 @@ For SAC the procedure is similar. The example shown only exports the actor netwo
onnxable_model,
dummy_input,
"my_sac_actor.onnx",
opset_version=9,
opset_version=17,
input_names=["input"],
)
Expand All @@ -147,10 +145,23 @@ For SAC the procedure is similar. The example shown only exports the actor netwo
observation = np.zeros((1, *observation_size)).astype(np.float32)
ort_sess = ort.InferenceSession(onnx_path)
action = ort_sess.run(None, {"input": observation})
scaled_action = ort_sess.run(None, {"input": observation})[0]
print(scaled_action)
# Post-process: rescale to correct space
# Rescale the action from [-1, 1] to [low, high]
# low, high = model.action_space.low, model.action_space.high
# post_processed_action = low + (0.5 * (scaled_action + 1.0) * (high - low))
# Check that the predictions are the same
with th.no_grad():
print(model.actor(th.as_tensor(observation), deterministic=True))
For more discussion around the topic, please refer to `GH#383 <https://github.com/DLR-RM/stable-baselines3/issues/383>`_ and `GH#1349 <https://github.com/DLR-RM/stable-baselines3/issues/1349>`_.


For more discussion around the topic refer to this `issue. <https://github.com/DLR-RM/stable-baselines3/issues/383>`_

Trace/Export to C++
-------------------
Expand Down
Loading

0 comments on commit 880bd2d

Please sign in to comment.