Update SB3 ONNX export documentation #1816

Merged (1 commit) on Jan 30, 2024
85 changes: 48 additions & 37 deletions docs/guide/export.rst
@@ -31,53 +31,52 @@ to do inference in another framework.
Export to ONNX
-----------------

As of June 2021, the ONNX format `doesn't support <https://github.com/onnx/onnx/issues/3033>`_ exporting models that use the ``broadcast_tensors`` functionality of PyTorch. So, to export trained stable-baselines3 models in the ONNX format, we first need to remove the layers that use broadcasting. This can be done by creating a class that removes the unsupported layers.

The following examples are for ``MlpPolicy`` only, and are general examples. Note that you have to preprocess the observation the same way the stable-baselines3 agent does (see ``common.preprocessing.preprocess_obs``).
If you are using PyTorch 2.0+ and ONNX Opset 14+, you can easily export SB3 policies using the following code:

For PPO, assuming a shared feature extractor.

.. warning::

The following example is for continuous actions only.
When using discrete or binary actions, you must do some `post-processing <https://github.com/DLR-RM/stable-baselines3/blob/f3a35aa786ee41ffff599b99fa1607c067e89074/stable_baselines3/common/policies.py#L621-L637>`_
to obtain the action (e.g., convert action logits to action).
The following returns normalized actions and doesn't include the `post-processing <https://github.com/DLR-RM/stable-baselines3/blob/a9273f968eaf8c6e04302a07d803eebfca6e7e86/stable_baselines3/common/policies.py#L370-L377>`_ step that is done with continuous actions
(clip or unscale the action to the correct space).


.. code-block:: python

import torch as th
from typing import Tuple

from stable_baselines3 import PPO
from stable_baselines3.common.policies import BasePolicy


class OnnxablePolicy(th.nn.Module):
    def __init__(self, extractor, action_net, value_net):
class OnnxableSB3Policy(th.nn.Module):
    def __init__(self, policy: BasePolicy):
        super().__init__()
        self.extractor = extractor
        self.action_net = action_net
        self.value_net = value_net
        self.policy = policy

    def forward(self, observation):
        # NOTE: You may have to process (normalize) observation in the correct
        # way before using this. See `common.preprocessing.preprocess_obs`
        action_hidden, value_hidden = self.extractor(observation)
        return self.action_net(action_hidden), self.value_net(value_hidden)
    def forward(self, observation: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        # NOTE: Preprocessing is included, but postprocessing
        # (clipping/unscaling actions) is not.
        # If needed, you also need to transpose the images so that they are channel-first.
        # Use deterministic=False if you want to export the stochastic policy.
        # policy() returns `actions, values, log_prob` for PPO
        return self.policy(observation, deterministic=True)


# Example: model = PPO("MlpPolicy", "Pendulum-v1")
PPO("MlpPolicy", "Pendulum-v1").save("PathToTrainedModel")
model = PPO.load("PathToTrainedModel.zip", device="cpu")
onnxable_model = OnnxablePolicy(
    model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net
)

onnx_policy = OnnxableSB3Policy(model.policy)

observation_size = model.observation_space.shape
dummy_input = th.randn(1, *observation_size)
th.onnx.export(
    onnxable_model,
    onnx_policy,
    dummy_input,
    "my_ppo_model.onnx",
    opset_version=9,
    opset_version=17,
    input_names=["input"],
)
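
Before running the exported policy, the file can be loaded back and sanity-checked. Below is a minimal sketch, assuming the ``onnx``, ``onnxruntime`` and ``numpy`` packages are installed; ``onnx_path`` refers to the file written by ``th.onnx.export`` above:

.. code-block:: python

import onnx
# onnxruntime and numpy are used by the inference snippet that follows
import onnxruntime as ort
import numpy as np

onnx_path = "my_ppo_model.onnx"
# Load the exported graph and run ONNX's structural checks on it
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)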

@@ -93,7 +92,13 @@ For PPO, assuming a shared feature extractor.

observation = np.zeros((1, *observation_size)).astype(np.float32)
ort_sess = ort.InferenceSession(onnx_path)
action, value = ort_sess.run(None, {"input": observation})
actions, values, log_prob = ort_sess.run(None, {"input": observation})

print(actions, values, log_prob)

# Check that the predictions are the same
with th.no_grad():
    print(model.policy(th.as_tensor(observation), deterministic=True))
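
As mentioned in the warning above, the actions returned here are the raw network outputs. A minimal sketch of the post-processing for continuous actions, assuming a ``Box`` action space and the default ``squash_output=False`` (``model`` and ``actions`` come from the snippets above):

.. code-block:: python

# Clip the actions to the action space bounds,
# which is what SB3's predict() does for continuous actions
low, high = model.action_space.low, model.action_space.high
clipped_actions = np.clip(actions, low, high)
print(clipped_actions)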


For SAC the procedure is similar. The example shown only exports the actor network as the actor is sufficient to roll out the trained policies.
@@ -108,23 +113,16 @@ For SAC the procedure is similar. The example shown only exports the actor network
class OnnxablePolicy(th.nn.Module):
    def __init__(self, actor: th.nn.Module):
        super().__init__()
        # Removing the flatten layer because it can't be onnxed
        self.actor = th.nn.Sequential(
            actor.latent_pi,
            actor.mu,
            # For gSDE
            # th.nn.Hardtanh(min_val=-actor.clip_mean, max_val=actor.clip_mean),
            # Squash the output
            th.nn.Tanh(),
        )
        self.actor = actor

    def forward(self, observation: th.Tensor) -> th.Tensor:
        # NOTE: You may have to process (normalize) observation in the correct
        # way before using this. See `common.preprocessing.preprocess_obs`
        return self.actor(observation)
        # NOTE: You may have to postprocess (unnormalize) actions
        # to the correct bounds (see commented code below)
        return self.actor(observation, deterministic=True)


# Example: model = SAC("MlpPolicy", "Pendulum-v1")
SAC("MlpPolicy", "Pendulum-v1").save("PathToTrainedModel.zip")
model = SAC.load("PathToTrainedModel.zip", device="cpu")
onnxable_model = OnnxablePolicy(model.policy.actor)

@@ -134,7 +132,7 @@ For SAC the procedure is similar. The example shown only exports the actor network
    onnxable_model,
    dummy_input,
    "my_sac_actor.onnx",
    opset_version=9,
    opset_version=17,
    input_names=["input"],
)

@@ -147,10 +145,23 @@ For SAC the procedure is similar. The example shown only exports the actor network

observation = np.zeros((1, *observation_size)).astype(np.float32)
ort_sess = ort.InferenceSession(onnx_path)
action = ort_sess.run(None, {"input": observation})
scaled_action = ort_sess.run(None, {"input": observation})[0]

print(scaled_action)

# Post-process: rescale to correct space
# Rescale the action from [-1, 1] to [low, high]
# low, high = model.action_space.low, model.action_space.high
# post_processed_action = low + (0.5 * (scaled_action + 1.0) * (high - low))

# Check that the predictions are the same
with th.no_grad():
    print(model.actor(th.as_tensor(observation), deterministic=True))
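
A minimal sketch of the rescaling described in the comments above, assuming a ``Box`` action space (e.g. ``Pendulum-v1``); the result can be compared to ``model.predict()``, which applies the same post-processing internally:

.. code-block:: python

# Rescale the squashed action from [-1, 1] to [low, high]
low, high = model.action_space.low, model.action_space.high
post_processed_action = low + (0.5 * (scaled_action + 1.0) * (high - low))

# predict() returns actions already rescaled to the action space
sb3_action, _ = model.predict(observation, deterministic=True)
print(post_processed_action, sb3_action)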


For more discussion around the topic, please refer to `GH#383 <https://github.com/DLR-RM/stable-baselines3/issues/383>`_ and `GH#1349 <https://github.com/DLR-RM/stable-baselines3/issues/1349>`_.


For more discussion around the topic, refer to this `issue <https://github.com/DLR-RM/stable-baselines3/issues/383>`_.
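
The NOTE comments in the examples above also mention that image observations must be transposed so that they are channel-first. A minimal sketch of that step, assuming a hypothetical 84x84 RGB observation stored channel-last as ``uint8`` (shapes and names are illustrative only):

.. code-block:: python

import numpy as np

# Hypothetical image observation: height x width x channels, uint8
obs_hwc = np.zeros((84, 84, 3), dtype=np.uint8)

# Transpose to channel-first and add a batch dimension.
# The exported policy already includes SB3's preprocessing
# (e.g. scaling of image pixels), so it is not repeated here.
obs_chw = np.transpose(obs_hwc, (2, 0, 1))[None].astype(np.float32)

# obs_chw can then be fed to the exported ONNX model:
# actions = ort_sess.run(None, {"input": obs_chw})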

Trace/Export to C++
-------------------
8 changes: 7 additions & 1 deletion docs/guide/rl_tips.rst
@@ -252,6 +252,12 @@ A better solution would be to use a squashing function (cf ``SAC``) or a Beta distribution
Tips and Tricks when implementing an RL algorithm
=================================================

.. note::

We have a `video on YouTube about reliable RL <https://www.youtube.com/watch?v=7-PUg9EAa3Y>`_ that covers
this section in more detail. You can also find the `slides online <https://araffin.github.io/slides/tips-reliable-rl/>`_.


When you try to reproduce an RL paper by implementing the algorithm, the `nuts and bolts of RL research <http://joschu.net/docs/nuts-and-bolts.pdf>`_
by John Schulman are quite useful (`video <https://www.youtube.com/watch?v=8EcdaCk9KaQ>`_).

@@ -282,4 +288,4 @@ in RL with discrete actions:
3. Pong (one of the easiest Atari games)
4. other Atari games (e.g. Breakout)

.. _SBX: https://github.com/araffin/sbx
.. _SBX: https://github.com/araffin/sbx
2 changes: 2 additions & 0 deletions docs/misc/changelog.rst
@@ -59,6 +59,8 @@ Documentation:
^^^^^^^^^^^^^^
- Added a paragraph on modifying vectorized environment parameters via setters (@fracapuano)
- Updated callback code example
- Updated export to ONNX documentation; it is now much simpler to export SB3 models with a newer ONNX Opset!
- Added video link to "Practical Tips for Reliable Reinforcement Learning" video

Release 2.2.1 (2023-11-17)
--------------------------