diff --git a/docs/guide/export.rst b/docs/guide/export.rst
index cccf30014..88a02fe8b 100644
--- a/docs/guide/export.rst
+++ b/docs/guide/export.rst
@@ -31,53 +31,52 @@ to do inference in another framework.
 Export to ONNX
 -----------------

-As of June 2021, ONNX format `doesn't support `_ exporting models that use the ``broadcast_tensors`` functionality of pytorch. So in order to export the trained stable-baseline3 models in the ONNX format, we need to first remove the layers that use broadcasting. This can be done by creating a class that removes the unsupported layers.

-The following examples are for ``MlpPolicy`` only, and are general examples. Note that you have to preprocess the observation the same way stable-baselines3 agent does (see ``common.preprocessing.preprocess_obs``).
+If you are using PyTorch 2.0+ and ONNX Opset 14+, you can easily export SB3 policies using the following code:

-For PPO, assuming a shared feature extractor.

 .. warning::

-    The following example is for continuous actions only.
-    When using discrete or binary actions, you must do some `post-processing `_
-    to obtain the action (e.g., convert action logits to action).
+    The following returns normalized actions and doesn't include the `post-processing `_ step that is done for continuous actions
+    (clipping or unscaling the action to the correct space).

 .. code-block:: python

     import torch as th
+    from typing import Tuple

     from stable_baselines3 import PPO
+    from stable_baselines3.common.policies import BasePolicy


-    class OnnxablePolicy(th.nn.Module):
-        def __init__(self, extractor, action_net, value_net):
+    class OnnxableSB3Policy(th.nn.Module):
+        def __init__(self, policy: BasePolicy):
             super().__init__()
-            self.extractor = extractor
-            self.action_net = action_net
-            self.value_net = value_net
+            self.policy = policy

-        def forward(self, observation):
-            # NOTE: You may have to process (normalize) observation in the correct
-            # way before using this. See `common.preprocessing.preprocess_obs`
-            action_hidden, value_hidden = self.extractor(observation)
-            return self.action_net(action_hidden), self.value_net(value_hidden)
+        def forward(self, observation: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
+            # NOTE: Preprocessing is included, but postprocessing
+            # (clipping/unscaling actions) is not.
+            # If needed, you also have to transpose images so that they are channel-first.
+            # Use deterministic=False if you want to export the stochastic policy.
+            # policy() returns `actions, values, log_prob` for PPO
+            return self.policy(observation, deterministic=True)


     # Example: model = PPO("MlpPolicy", "Pendulum-v1")
+    PPO("MlpPolicy", "Pendulum-v1").save("PathToTrainedModel")
     model = PPO.load("PathToTrainedModel.zip", device="cpu")
-    onnxable_model = OnnxablePolicy(
-        model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net
-    )
+
+    onnx_policy = OnnxableSB3Policy(model.policy)

     observation_size = model.observation_space.shape
     dummy_input = th.randn(1, *observation_size)
     th.onnx.export(
-        onnxable_model,
+        onnx_policy,
         dummy_input,
         "my_ppo_model.onnx",
-        opset_version=9,
+        opset_version=17,
         input_names=["input"],
     )

@@ -93,7 +92,13 @@ For PPO, assuming a shared feature extractor.

     observation = np.zeros((1, *observation_size)).astype(np.float32)
     ort_sess = ort.InferenceSession(onnx_path)
-    action, value = ort_sess.run(None, {"input": observation})
+    actions, values, log_prob = ort_sess.run(None, {"input": observation})
+
+    print(actions, values, log_prob)
+
+    # Check that the predictions are the same
+    with th.no_grad():
+        print(model.policy(th.as_tensor(observation), deterministic=True))
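+
+    # A minimal post-processing sketch, assuming a continuous (Box) action space:
+    # clip the actions to the action space bounds, as ``model.predict()`` does.
+    clipped_actions = np.clip(actions, model.action_space.low, model.action_space.high)
+    print(clipped_actions)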


 For SAC the procedure is similar. The example shown only exports the actor network as the actor is sufficient to roll out the trained policies.
@@ -108,23 +113,16 @@ For SAC the procedure is similar. The example shown only exports the actor netwo

     class OnnxablePolicy(th.nn.Module):
         def __init__(self, actor: th.nn.Module):
             super().__init__()
-            # Removing the flatten layer because it can't be onnxed
-            self.actor = th.nn.Sequential(
-                actor.latent_pi,
-                actor.mu,
-                # For gSDE
-                # th.nn.Hardtanh(min_val=-actor.clip_mean, max_val=actor.clip_mean),
-                # Squash the output
-                th.nn.Tanh(),
-            )
+            self.actor = actor

         def forward(self, observation: th.Tensor) -> th.Tensor:
-            # NOTE: You may have to process (normalize) observation in the correct
-            # way before using this. See `common.preprocessing.preprocess_obs`
-            return self.actor(observation)
+            # NOTE: You may have to postprocess (unnormalize) actions
+            # to the correct bounds (see commented code below)
+            return self.actor(observation, deterministic=True)


     # Example: model = SAC("MlpPolicy", "Pendulum-v1")
+    SAC("MlpPolicy", "Pendulum-v1").save("PathToTrainedModel.zip")
     model = SAC.load("PathToTrainedModel.zip", device="cpu")
     onnxable_model = OnnxablePolicy(model.policy.actor)
@@ -134,7 +132,7 @@ For SAC the procedure is similar. The example shown only exports the actor netwo
         onnxable_model,
         dummy_input,
         "my_sac_actor.onnx",
-        opset_version=9,
+        opset_version=17,
         input_names=["input"],
     )

@@ -147,10 +145,23 @@ For SAC the procedure is similar. The example shown only exports the actor netwo
     observation = np.zeros((1, *observation_size)).astype(np.float32)
     ort_sess = ort.InferenceSession(onnx_path)
-    action = ort_sess.run(None, {"input": observation})
+    scaled_action = ort_sess.run(None, {"input": observation})[0]
+
+    print(scaled_action)
+
+    # Post-process: rescale to correct space
+    # Rescale the action from [-1, 1] to [low, high]
+    # low, high = model.action_space.low, model.action_space.high
+    # post_processed_action = low + (0.5 * (scaled_action + 1.0) * (high - low))
+
+    # Check that the predictions are the same
+    with th.no_grad():
+        print(model.actor(th.as_tensor(observation), deterministic=True))
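+
+    # A minimal sketch of that post-processing, assuming a ``Box`` action space:
+    # the rescaled action should match ``model.predict()``, which applies
+    # the same unscaling internally.
+    low, high = model.action_space.low, model.action_space.high
+    post_processed_action = low + (0.5 * (scaled_action + 1.0) * (high - low))
+    print(post_processed_action, model.predict(observation, deterministic=True)[0])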
+
+
+For more discussion around the topic, please refer to `GH#383 `_ and `GH#1349 `_.
-For more discussion around the topic refer to this `issue. `_

 Trace/Export to C++
 -------------------
diff --git a/docs/guide/rl_tips.rst b/docs/guide/rl_tips.rst
index ce6f43e55..ae37640c7 100644
--- a/docs/guide/rl_tips.rst
+++ b/docs/guide/rl_tips.rst
@@ -252,6 +252,12 @@ A better solution would be to use a squashing function (cf ``SAC``) or a Beta di

 Tips and Tricks when implementing an RL algorithm
 =================================================

+.. note::
+
+    We have a `video on YouTube about reliable RL `_ that covers
+    this section in more detail. You can also find the `slides online `_.
+
+
 When you try to reproduce a RL paper by implementing the algorithm, the `nuts and bolts of RL research `_
 by John Schulman are quite useful (`video `_).
@@ -282,4 +288,4 @@ in RL with discrete actions:
 3. Pong (one of the easiest Atari game)
 4. other Atari games (e.g. Breakout)

-.. _SBX: https://github.com/araffin/sbx
\ No newline at end of file
+.. _SBX: https://github.com/araffin/sbx
diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index a4d8e6373..4f22c672b 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -59,6 +59,8 @@ Documentation:
 ^^^^^^^^^^^^^^
 - Added a paragraph on modifying vectorized environment parameters via setters (@fracapuano)
 - Updated callback code example
+- Updated the export to ONNX documentation: it is now much simpler to export SB3 models with newer ONNX opsets!
+- Added a link to the "Practical Tips for Reliable Reinforcement Learning" video

 Release 2.2.1 (2023-11-17)
 --------------------------