From a9273f968eaf8c6e04302a07d803eebfca6e7e86 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 12 Jan 2024 16:05:14 +0100 Subject: [PATCH] Update TD3/DDPG/DQN defaults for consistency (#1785) * Update TD3/DDPG/DQN defaults for consistency * Update changelog --- docs/misc/changelog.rst | 28 ++++++++++++++++++++++++++-- stable_baselines3/ddpg/ddpg.py | 6 +++--- stable_baselines3/dqn/dqn.py | 2 +- stable_baselines3/td3/td3.py | 6 +++--- stable_baselines3/version.txt | 2 +- 5 files changed, 34 insertions(+), 10 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index cbfe41f9d..a4d8e6373 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,12 +3,36 @@ Changelog ========== - -Release 2.3.0a0 (WIP) +Release 2.3.0a1 (WIP) -------------------------- Breaking Changes: ^^^^^^^^^^^^^^^^^ +- The defaults hyperparameters of ``TD3`` and ``DDPG`` have been changed to be more consistent with ``SAC`` + +.. code-block:: python + + # SB3 < 2.3.0 default hyperparameters + # model = TD3("MlpPolicy", env, train_freq=(1, "episode"), gradient_steps=-1, batch_size=100) + # SB3 >= 2.3.0: + model = TD3("MlpPolicy", env, train_freq=1, gradient_steps=1, batch_size=256) + +.. note:: + + Two inconsistencies remains: the default network architecture for ``TD3/DDPG`` is ``[400, 300]`` instead of ``[256, 256]`` for SAC (for backward compatibility reasons, see `report on the influence of the network size `_) and the default learning rate is 1e-3 instead of 3e-4 for SAC (for performance reasons, see `W&B report on the influence of the lr `_) + + + +- The default ``leanrning_starts`` parameter of ``DQN`` have been changed to be consistent with the other offpolicy algorithms + + +.. code-block:: python + + # SB3 < 2.3.0 default hyperparameters, 50_000 corresponded to Atari defaults hyperparameters + # model = DQN("MlpPolicy", env, learning_start=50_000) + # SB3 >= 2.3.0: + model = DQN("MlpPolicy", env, learning_start=100) + New Features: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/ddpg/ddpg.py b/stable_baselines3/ddpg/ddpg.py index c311b2357..2fe2fdfc4 100644 --- a/stable_baselines3/ddpg/ddpg.py +++ b/stable_baselines3/ddpg/ddpg.py @@ -60,11 +60,11 @@ def __init__( learning_rate: Union[float, Schedule] = 1e-3, buffer_size: int = 1_000_000, # 1e6 learning_starts: int = 100, - batch_size: int = 100, + batch_size: int = 256, tau: float = 0.005, gamma: float = 0.99, - train_freq: Union[int, Tuple[int, str]] = (1, "episode"), - gradient_steps: int = -1, + train_freq: Union[int, Tuple[int, str]] = 1, + gradient_steps: int = 1, action_noise: Optional[ActionNoise] = None, replay_buffer_class: Optional[Type[ReplayBuffer]] = None, replay_buffer_kwargs: Optional[Dict[str, Any]] = None, diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index 42e3d0df0..894ed9f04 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -79,7 +79,7 @@ def __init__( env: Union[GymEnv, str], learning_rate: Union[float, Schedule] = 1e-4, buffer_size: int = 1_000_000, # 1e6 - learning_starts: int = 50000, + learning_starts: int = 100, batch_size: int = 32, tau: float = 1.0, gamma: float = 0.99, diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index a06ce67e0..a61d954bc 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -83,11 +83,11 @@ def __init__( learning_rate: Union[float, Schedule] = 1e-3, buffer_size: int = 1_000_000, # 1e6 learning_starts: int = 100, - batch_size: int = 100, + batch_size: int = 256, tau: float = 0.005, gamma: float = 0.99, - train_freq: Union[int, Tuple[int, str]] = (1, "episode"), - gradient_steps: int = -1, + train_freq: Union[int, Tuple[int, str]] = 1, + gradient_steps: int = 1, action_noise: Optional[ActionNoise] = None, replay_buffer_class: Optional[Type[ReplayBuffer]] = None, replay_buffer_kwargs: Optional[Dict[str, Any]] = None, diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 00b35529e..4d04ad95c 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.3.0a0 +2.3.0a1