Skip to content

Commit

Permalink
Update changelog and SBX doc
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin committed Apr 22, 2024
1 parent d35ee6a commit c1d8e60
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 9 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ Documentation is available online: [https://sb3-contrib.readthedocs.io/](https:/

## Stable-Baselines Jax (SBX)

[Stable Baselines Jax (SBX)](https://github.com/araffin/sbx) is a proof of concept version of Stable-Baselines3 in Jax.
[Stable Baselines Jax (SBX)](https://github.com/araffin/sbx) is a proof of concept version of Stable-Baselines3 in Jax, with recent algorithms like DroQ or CrossQ.

It provides a minimal number of features compared to SB3 but can be much faster (up to 20x times!): https://twitter.com/araffin2/status/1590714558628253698

Expand Down
3 changes: 2 additions & 1 deletion docs/guide/algos.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ Actions ``gym.spaces``:

.. note::

More algorithms (like QR-DQN or TQC) are implemented in our :ref:`contrib repo <sb3_contrib>`.
More algorithms (like QR-DQN or TQC) are implemented in our :ref:`contrib repo <sb3_contrib>`
and in our :ref:`SBX (SB3 + Jax) repo <sbx>` (DroQ, CrossQ, ...).

.. note::

Expand Down
15 changes: 9 additions & 6 deletions docs/guide/sbx.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Implemented algorithms:
- Deep Q Network (DQN)
- Twin Delayed DDPG (TD3)
- Deep Deterministic Policy Gradient (DDPG)
- Batch Normalization in Deep Reinforcement Learning (CrossQ)


As SBX follows SB3 API, it is also compatible with the `RL Zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_.
Expand All @@ -29,16 +30,17 @@ For that you will need to create two files:
import rl_zoo3
import rl_zoo3.train
from rl_zoo3.train import train
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ
rl_zoo3.ALGOS["ddpg"] = DDPG
rl_zoo3.ALGOS["dqn"] = DQN
rl_zoo3.ALGOS["droq"] = DroQ
# See SBX readme to use DroQ configuration
# rl_zoo3.ALGOS["droq"] = DroQ
rl_zoo3.ALGOS["sac"] = SAC
rl_zoo3.ALGOS["ppo"] = PPO
rl_zoo3.ALGOS["td3"] = TD3
rl_zoo3.ALGOS["tqc"] = TQC
rl_zoo3.ALGOS["crossq"] = CrossQ
rl_zoo3.train.ALGOS = rl_zoo3.ALGOS
rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS
Expand All @@ -56,16 +58,17 @@ Then you can call ``python train_sbx.py --algo sac --env Pendulum-v1`` and use t
import rl_zoo3
import rl_zoo3.enjoy
from rl_zoo3.enjoy import enjoy
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ
rl_zoo3.ALGOS["ddpg"] = DDPG
rl_zoo3.ALGOS["dqn"] = DQN
rl_zoo3.ALGOS["droq"] = DroQ
# See SBX readme to use DroQ configuration
# rl_zoo3.ALGOS["droq"] = DroQ
rl_zoo3.ALGOS["sac"] = SAC
rl_zoo3.ALGOS["ppo"] = PPO
rl_zoo3.ALGOS["td3"] = TD3
rl_zoo3.ALGOS["tqc"] = TQC
rl_zoo3.ALGOS["crossq"] = CrossQ
rl_zoo3.enjoy.ALGOS = rl_zoo3.ALGOS
rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS
Expand Down
13 changes: 12 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,18 @@
Changelog
==========

Release 2.3.1 (2024-04-22)
--------------------------

Bug Fixes:
^^^^^^^^^^
- Cast return value of learning rate schedule to float, to avoid issue when loading model because of ``weights_only=True`` (@markscsmith)

Documentation:
^^^^^^^^^^^^^^
- Updated SBX documentation (CrossQ and deprecated DroQ)


Release 2.3.0 (2024-03-31)
--------------------------

Expand Down Expand Up @@ -48,7 +60,6 @@ New Features:
Bug Fixes:
^^^^^^^^^^
- Fixed ``monitor_wrapper`` argument that was not passed to the parent class, and dones argument that wasn't passed to ``_update_into_buffer`` (@corentinlger)
- Fixed ``learning_rate`` argument that could cause weights_only=True to fail if passed a function with non-float types (e.g. ``learning_rate=lambda _: np.sin(1.0)``) (@markscsmith)

`SB3-Contrib`_
^^^^^^^^^^^^^^
Expand Down

0 comments on commit c1d8e60

Please sign in to comment.