Update SB3 ONNX export documentation (#1816)
araffin authored Jan 30, 2024
1 parent a9273f9 commit 620e58e
Showing 3 changed files with 57 additions and 38 deletions.
docs/guide/export.rst (85 changes: 48 additions & 37 deletions)
@@ -31,53 +31,52 @@ to do inference in another framework.
Export to ONNX
-----------------

-As of June 2021, ONNX format `doesn't support <https://github.com/onnx/onnx/issues/3033>`_ exporting models that use the ``broadcast_tensors`` functionality of pytorch. So in order to export the trained stable-baseline3 models in the ONNX format, we need to first remove the layers that use broadcasting. This can be done by creating a class that removes the unsupported layers.
-
-The following examples are for ``MlpPolicy`` only, and are general examples. Note that you have to preprocess the observation the same way stable-baselines3 agent does (see ``common.preprocessing.preprocess_obs``).
+If you are using PyTorch 2.0+ and ONNX Opset 14+, you can easily export SB3 policies using the following code:

-For PPO, assuming a shared feature extractor.

.. warning::

-  The following example is for continuous actions only.
-  When using discrete or binary actions, you must do some `post-processing <https://github.com/DLR-RM/stable-baselines3/blob/f3a35aa786ee41ffff599b99fa1607c067e89074/stable_baselines3/common/policies.py#L621-L637>`_
-  to obtain the action (e.g., convert action logits to action).
+  The following returns normalized actions and doesn't include the `post-processing <https://github.com/DLR-RM/stable-baselines3/blob/a9273f968eaf8c6e04302a07d803eebfca6e7e86/stable_baselines3/common/policies.py#L370-L377>`_ step that is done with continuous actions
+  (clip or unscale the action to the correct space).

.. code-block:: python

    import torch as th
+   from typing import Tuple

    from stable_baselines3 import PPO
+   from stable_baselines3.common.policies import BasePolicy


-   class OnnxablePolicy(th.nn.Module):
-       def __init__(self, extractor, action_net, value_net):
+   class OnnxableSB3Policy(th.nn.Module):
+       def __init__(self, policy: BasePolicy):
            super().__init__()
-           self.extractor = extractor
-           self.action_net = action_net
-           self.value_net = value_net
+           self.policy = policy

-       def forward(self, observation):
-           # NOTE: You may have to process (normalize) observation in the correct
-           #       way before using this. See `common.preprocessing.preprocess_obs`
-           action_hidden, value_hidden = self.extractor(observation)
-           return self.action_net(action_hidden), self.value_net(value_hidden)
+       def forward(self, observation: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
+           # NOTE: Preprocessing is included, but postprocessing
+           # (clipping/unscaling actions) is not.
+           # If needed, you also need to transpose the images so that they are channel first.
+           # Use deterministic=False if you want to export the stochastic policy.
+           # policy() returns `actions, values, log_prob` for PPO
+           return self.policy(observation, deterministic=True)


    # Example: model = PPO("MlpPolicy", "Pendulum-v1")
+   PPO("MlpPolicy", "Pendulum-v1").save("PathToTrainedModel")
    model = PPO.load("PathToTrainedModel.zip", device="cpu")

-   onnxable_model = OnnxablePolicy(
-       model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net
-   )
+   onnx_policy = OnnxableSB3Policy(model.policy)

    observation_size = model.observation_space.shape
    dummy_input = th.randn(1, *observation_size)
    th.onnx.export(
-       onnxable_model,
+       onnx_policy,
        dummy_input,
        "my_ppo_model.onnx",
-       opset_version=9,
+       opset_version=17,
        input_names=["input"],
    )
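
Between this hunk and the next one, the collapsed lines presumably import ``onnx``/``onnxruntime`` and define ``onnx_path``. A minimal sketch of such a setup, assuming the ``my_ppo_model.onnx`` file produced by the export call above, could look like this:

.. code-block:: python

    import numpy as np
    import onnx
    import onnxruntime as ort

    onnx_path = "my_ppo_model.onnx"

    # Optional structural sanity check of the exported graph
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
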
@@ -93,7 +92,13 @@ For PPO, assuming a shared feature extractor.
    observation = np.zeros((1, *observation_size)).astype(np.float32)
    ort_sess = ort.InferenceSession(onnx_path)
-   action, value = ort_sess.run(None, {"input": observation})
+   actions, values, log_prob = ort_sess.run(None, {"input": observation})
+   print(actions, values, log_prob)

+   # Check that the predictions are the same
+   with th.no_grad():
+       print(model.policy(th.as_tensor(observation), deterministic=True))
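
Because the exported policy returns unclipped actions for continuous action spaces (see the warning above), a minimal post-processing sketch, assuming a ``Box`` action space such as ``Pendulum-v1`` and reusing ``actions`` and ``model`` from the snippet above, could be:

.. code-block:: python

    import numpy as np

    # Clip the raw network output to the action-space bounds,
    # similar to what `model.predict()` does for continuous actions
    clipped_actions = np.clip(actions, model.action_space.low, model.action_space.high)
    print(clipped_actions)
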
For SAC the procedure is similar. The example shown only exports the actor network as the actor is sufficient to roll out the trained policies.
@@ -108,23 +113,16 @@ For SAC the procedure is similar. The example shown only exports the actor network as the actor is sufficient to roll out the trained policies.
    class OnnxablePolicy(th.nn.Module):
        def __init__(self, actor: th.nn.Module):
            super().__init__()
-           # Removing the flatten layer because it can't be onnxed
-           self.actor = th.nn.Sequential(
-               actor.latent_pi,
-               actor.mu,
-               # For gSDE
-               # th.nn.Hardtanh(min_val=-actor.clip_mean, max_val=actor.clip_mean),
-               # Squash the output
-               th.nn.Tanh(),
-           )
+           self.actor = actor

        def forward(self, observation: th.Tensor) -> th.Tensor:
-           # NOTE: You may have to process (normalize) observation in the correct
-           #       way before using this. See `common.preprocessing.preprocess_obs`
-           return self.actor(observation)
+           # NOTE: You may have to postprocess (unnormalize) actions
+           # to the correct bounds (see commented code below)
+           return self.actor(observation, deterministic=True)


    # Example: model = SAC("MlpPolicy", "Pendulum-v1")
+   SAC("MlpPolicy", "Pendulum-v1").save("PathToTrainedModel.zip")
    model = SAC.load("PathToTrainedModel.zip", device="cpu")

    onnxable_model = OnnxablePolicy(model.policy.actor)
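
The next hunk resumes inside the ``th.onnx.export(...)`` call, so the collapsed lines in between presumably build the dummy input first, roughly along these lines (a sketch reusing ``model`` from above):

.. code-block:: python

    observation_size = model.observation_space.shape
    dummy_input = th.randn(1, *observation_size)
    # th.onnx.export(...) with the arguments shown in the hunk below
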
@@ -134,7 +132,7 @@ For SAC the procedure is similar. The example shown only exports the actor network as the actor is sufficient to roll out the trained policies.
        onnxable_model,
        dummy_input,
        "my_sac_actor.onnx",
-       opset_version=9,
+       opset_version=17,
        input_names=["input"],
    )
@@ -147,10 +145,23 @@ For SAC the procedure is similar. The example shown only exports the actor network as the actor is sufficient to roll out the trained policies.
    observation = np.zeros((1, *observation_size)).astype(np.float32)
    ort_sess = ort.InferenceSession(onnx_path)
-   action = ort_sess.run(None, {"input": observation})
+   scaled_action = ort_sess.run(None, {"input": observation})[0]
+   print(scaled_action)

+   # Post-process: rescale to correct space
+   # Rescale the action from [-1, 1] to [low, high]
+   # low, high = model.action_space.low, model.action_space.high
+   # post_processed_action = low + (0.5 * (scaled_action + 1.0) * (high - low))

+   # Check that the predictions are the same
+   with th.no_grad():
+       print(model.actor(th.as_tensor(observation), deterministic=True))
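
As a worked version of the commented post-processing above, a small helper (a sketch assuming a ``Box`` action space such as ``Pendulum-v1``'s, with ``scaled_action`` and ``model`` taken from the snippet above; ``unscale_action`` is just an illustrative name) could be:

.. code-block:: python

    import numpy as np

    def unscale_action(scaled_action: np.ndarray, low: np.ndarray, high: np.ndarray) -> np.ndarray:
        """Map an action from [-1, 1] back to [low, high]."""
        return low + 0.5 * (scaled_action + 1.0) * (high - low)

    post_processed_action = unscale_action(
        scaled_action, model.action_space.low, model.action_space.high
    )
    print(post_processed_action)
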
+For more discussion around the topic, please refer to `GH#383 <https://github.com/DLR-RM/stable-baselines3/issues/383>`_ and `GH#1349 <https://github.com/DLR-RM/stable-baselines3/issues/1349>`_.


-For more discussion around the topic refer to this `issue. <https://github.com/DLR-RM/stable-baselines3/issues/383>`_

Trace/Export to C++
-------------------
docs/guide/rl_tips.rst (8 changes: 7 additions & 1 deletion)
@@ -252,6 +252,12 @@ A better solution would be to use a squashing function (cf ``SAC``) or a Beta distribution
Tips and Tricks when implementing an RL algorithm
=================================================

+.. note::
+
+  We have a `video on YouTube about reliable RL <https://www.youtube.com/watch?v=7-PUg9EAa3Y>`_ that covers
+  this section in more detail. You can also find the `slides online <https://araffin.github.io/slides/tips-reliable-rl/>`_.


When you try to reproduce a RL paper by implementing the algorithm, the `nuts and bolts of RL research <http://joschu.net/docs/nuts-and-bolts.pdf>`_
by John Schulman are quite useful (`video <https://www.youtube.com/watch?v=8EcdaCk9KaQ>`_).

@@ -282,4 +288,4 @@ in RL with discrete actions:
3. Pong (one of the easiest Atari game)
4. other Atari games (e.g. Breakout)

-.. _SBX: https://github.com/araffin/sbx
\ No newline at end of file
+.. _SBX: https://github.com/araffin/sbx
docs/misc/changelog.rst (2 changes: 2 additions & 0 deletions)
@@ -59,6 +59,8 @@ Documentation:
^^^^^^^^^^^^^^
- Added a paragraph on modifying vectorized environment parameters via setters (@fracapuano)
- Updated callback code example
+- Updated the export to ONNX documentation; it is now much simpler to export SB3 models with newer ONNX opsets!
+- Added a link to the "Practical Tips for Reliable Reinforcement Learning" video

Release 2.2.1 (2023-11-17)
--------------------------
