Stable-Baselines3 v2.3.0: New default hyperparameters for DDPG, TD3 and DQN
Warning
Because of `weights_only=True`, this release breaks loading of policies when using PyTorch 1.13.
Please upgrade to PyTorch >= 2.0 or upgrade your SB3 version (we reverted the change in SB3 2.3.2).
SB3 Contrib (more algorithms): https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
RL Zoo3 (training framework): https://github.com/DLR-RM/rl-baselines3-zoo
Stable-Baselines Jax (SBX): https://github.com/araffin/sbx
To upgrade:
```bash
pip install stable_baselines3 sb3_contrib --upgrade
```

or simply (the RL Zoo depends on SB3 and SB3 Contrib):

```bash
pip install rl_zoo3 --upgrade
```
Breaking Changes:
- The default hyperparameters of `TD3` and `DDPG` have been changed to be more consistent with `SAC`:
```python
from stable_baselines3 import TD3  # assumes `env` is an existing Gym/Gymnasium environment

# SB3 < 2.3.0 default hyperparameters
# model = TD3("MlpPolicy", env, train_freq=(1, "episode"), gradient_steps=-1, batch_size=100)

# SB3 >= 2.3.0:
model = TD3("MlpPolicy", env, train_freq=1, gradient_steps=1, batch_size=256)
```
Note
Two inconsistencies remain: the default network architecture for TD3/DDPG is `[400, 300]` instead of the `[256, 256]` used by SAC (for backward-compatibility reasons; see the report on the influence of the network size), and the default learning rate is 1e-3 instead of the 3e-4 used by SAC (for performance reasons; see the W&B report on the influence of the learning rate).
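If you prefer to match SAC's defaults for these two remaining settings, they can be overridden at construction time. A minimal sketch (the environment id is illustrative):

```python
from stable_baselines3 import TD3

# Override the two remaining SAC-inconsistent defaults mentioned above:
# SAC-style network size and learning rate for TD3.
model = TD3(
    "MlpPolicy",
    "Pendulum-v1",
    policy_kwargs=dict(net_arch=[256, 256]),
    learning_rate=3e-4,
)
```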
- The default `learning_starts` parameter of `DQN` has been changed to be consistent with the other off-policy algorithms:
```python
from stable_baselines3 import DQN  # assumes `env` is an existing Gym/Gymnasium environment

# SB3 < 2.3.0 default hyperparameters, 50_000 corresponded to the Atari default hyperparameters
# model = DQN("MlpPolicy", env, learning_starts=50_000)

# SB3 >= 2.3.0:
model = DQN("MlpPolicy", env, learning_starts=100)
```
- For safety, `torch.load()` is now called with `weights_only=True` when loading torch tensors; the policy `load()` still uses `weights_only=False`, as gymnasium imports are required for it to work
- When using `huggingface_sb3`, you will now need to set `TRUST_REMOTE_CODE=True` when downloading models from the hub, as `pickle.load` is not safe.
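For illustration, a minimal sketch of downloading a checkpoint from the Hugging Face Hub under the new behaviour. It assumes `TRUST_REMOTE_CODE` is read from the environment by `huggingface_sb3`, and the repository id and filename are illustrative; check the `huggingface_sb3` documentation for the exact mechanism:

```python
import os

# Assumption: huggingface_sb3 looks at this flag before unpickling remote
# checkpoints; only enable it for repositories you trust.
os.environ["TRUST_REMOTE_CODE"] = "True"

from huggingface_sb3 import load_from_hub
from stable_baselines3 import PPO

# Download the zip file from the Hub and load it with SB3.
checkpoint = load_from_hub(
    repo_id="sb3/ppo-CartPole-v1",
    filename="ppo-CartPole-v1.zip",
)
model = PPO.load(checkpoint)
```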
New Features:
- Log success rate `rollout/success_rate` when available for on-policy algorithms (@corentinlger)
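The success rate is aggregated from the `is_success` key of the terminal step's `info` dict. A minimal sketch with a toy environment (`GoalBitEnv` is hypothetical, not part of SB3):

```python
import gymnasium as gym
import numpy as np
from stable_baselines3 import PPO

class GoalBitEnv(gym.Env):
    """Toy env: one-step episodes that succeed when action 1 is chosen."""
    observation_space = gym.spaces.Box(-1.0, 1.0, shape=(1,), dtype=np.float32)
    action_space = gym.spaces.Discrete(2)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        return np.zeros(1, dtype=np.float32), {}

    def step(self, action):
        success = bool(action == 1)
        reward = 1.0 if success else 0.0
        # Report "is_success" in the terminal info dict so SB3 can log rollout/success_rate
        return np.zeros(1, dtype=np.float32), reward, True, False, {"is_success": success}

model = PPO("MlpPolicy", GoalBitEnv(), verbose=1)
model.learn(total_timesteps=2_048)  # rollout/success_rate appears in the logs
```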
Bug Fixes:
- Fixed `monitor_wrapper` argument that was not passed to the parent class, and `dones` argument that wasn't passed to `_update_info_buffer` (@corentinlger)
SB3-Contrib
- Added `rollout_buffer_class` and `rollout_buffer_kwargs` arguments to `MaskablePPO` (see the sketch after this list)
- Fixed `train_freq` type annotation for TQC and QR-DQN (@Armandpl)
- Fixed `sb3_contrib/common/maskable/*.py` type annotations
- Fixed `sb3_contrib/ppo_mask/ppo_mask.py` type annotations
- Fixed `sb3_contrib/common/vec_env/async_eval.py` type annotations
- Added some additional notes about `MaskablePPO` (evaluation and multi-process) (@icheered)
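A rough sketch of the new `MaskablePPO` arguments; the buffer subclass and the all-valid mask function are hypothetical placeholders:

```python
import numpy as np
import gymnasium as gym
from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.buffers import MaskableRolloutBuffer
from sb3_contrib.common.wrappers import ActionMasker

class MyMaskableBuffer(MaskableRolloutBuffer):
    """Placeholder subclass, e.g. for extra logging or custom storage."""

def mask_fn(env):
    # Toy mask: every action is valid.
    return np.ones(env.action_space.n, dtype=bool)

env = ActionMasker(gym.make("CartPole-v1"), mask_fn)
model = MaskablePPO(
    "MlpPolicy",
    env,
    rollout_buffer_class=MyMaskableBuffer,
    rollout_buffer_kwargs={},  # forwarded to the buffer constructor
)
```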
RL Zoo
- Updated default hyperparameters for TD3/DDPG to be more consistent with SAC
- Upgraded MuJoCo env hyperparameters to v4 (pre-trained agents need to be updated)
- Added test dependencies to `setup.py` (@power-edge)
- Simplified dependencies in `requirements.txt` (removed duplicates from `setup.py`)
SBX (SB3 + Jax)
- Added support for `MultiDiscrete` and `MultiBinary` action spaces to PPO
- Added support for large values of `gradient_steps` for SAC, TD3, and TQC
- Fixed `train()` signature and updated type hints
- Fixed replay buffer device at load time
- Added flatten layer
- Added CrossQ
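A minimal sketch of trying the newly added CrossQ in SBX, assuming it is exported at the package top level like the other SBX algorithms (check the SBX README for the exact import path):

```python
from sbx import CrossQ  # SBX: SB3 + Jax

# Train CrossQ on a small continuous-control task.
model = CrossQ("MlpPolicy", "Pendulum-v1", verbose=1)
model.learn(total_timesteps=10_000)
```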
Others:
- Updated black from v23 to v24
- Updated ruff to >= v0.3.1
- Updated env checker for (multi)discrete spaces with non-zero start.
Documentation:
- Added a paragraph on modifying vectorized environment parameters via setters (@fracapuano); see the sketch after this list
- Updated callback code example
- Updated export-to-ONNX documentation; it is now much simpler to export SB3 models with a newer ONNX opset!
- Added a link to the "Practical Tips for Reliable Reinforcement Learning" video
- Added `render_mode="human"` in the README example (@marekm4)
- Fixed docstring signature for `sum_independent_dims` (@StagOverflow)
- Updated docstring description for `log_interval` in the base class (@rushitnshah)
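Relatedly, a minimal sketch of the vectorized-environment setter API (`my_param` is a hypothetical attribute, used only for illustration):

```python
import gymnasium as gym
from stable_baselines3.common.vec_env import DummyVecEnv

# Build a vectorized env with two copies of CartPole.
vec_env = DummyVecEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])

# Set a (hypothetical) attribute on every sub-environment, then read it back.
vec_env.set_attr("my_param", 0.5)
print(vec_env.get_attr("my_param"))  # [0.5, 0.5]
```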
Full Changelog: v2.2.1...v2.3.0