diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 31ff99d09..37a035478 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,9 +3,16 @@ Changelog ========== -Release 2.4.0a6 (WIP) +Release 2.4.0a7 (WIP) -------------------------- +.. note:: + + DQN (and QR-DQN) models saved with SB3 < 2.4.0 will show a warning about + truncation of optimizer state when loaded with SB3 >= 2.4.0. + To suppress the warning, simply save the model again. + You can find more info in `PR #1963 `_ + Breaking Changes: ^^^^^^^^^^^^^^^^^ @@ -28,9 +35,11 @@ Bug Fixes: `RL Zoo`_ ^^^^^^^^^ +- Updated defaults hyperparameters for TQC/SAC for Swimmer-v4 (decrease gamma for more consistent results) `SBX`_ (SB3 + Jax) ^^^^^^^^^^^^^^^^^^ +- Added CNN support for DQN Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index b2c967405..e43955f94 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -742,13 +742,13 @@ def load( # noqa: C901 # put state_dicts back in place model.set_parameters(params, exact_match=True, device=device) except RuntimeError as e: - # Patch to load Policy saved using SB3 < 1.7.0 + # Patch to load policies saved using SB3 < 1.7.0 # the error is probably due to old policy being loaded # See https://github.com/DLR-RM/stable-baselines3/issues/1233 if "pi_features_extractor" in str(e) and "Missing key(s) in state_dict" in str(e): model.set_parameters(params, exact_match=False, device=device) warnings.warn( - "You are probably loading a model saved with SB3 < 1.7.0, " + "You are probably loading a A2C/PPO model saved with SB3 < 1.7.0, " "we deactivated exact_match so you can save the model " "again to avoid issues in the future " "(see https://github.com/DLR-RM/stable-baselines3/issues/1233 for more info). " @@ -757,6 +757,29 @@ def load( # noqa: C901 ) else: raise e + except ValueError as e: + # Patch to load DQN policies saved using SB3 < 2.4.0 + # The target network params are no longer in the optimizer + # See https://github.com/DLR-RM/stable-baselines3/pull/1963 + saved_optim_params = params["policy.optimizer"]["param_groups"][0]["params"] # type: ignore[index] + n_params_saved = len(saved_optim_params) + n_params = len(model.policy.optimizer.param_groups[0]["params"]) + if n_params_saved == 2 * n_params: + # Truncate to include only online network params + params["policy.optimizer"]["param_groups"][0]["params"] = saved_optim_params[:n_params] # type: ignore[index] + + model.set_parameters(params, exact_match=True, device=device) + warnings.warn( + "You are probably loading a DQN model saved with SB3 < 2.4.0, " + "we truncated the optimizer state so you can save the model " + "again to avoid issues in the future " + "(see https://github.com/DLR-RM/stable-baselines3/pull/1963 for more info). " + f"Original error: {e} \n" + "Note: the model should still work fine, this only a warning." + ) + else: + raise e + # put other pytorch variables back in place if pytorch_variables is not None: for name in pytorch_variables: diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 464a5c4dc..f5230e413 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.4.0a6 +2.4.0a7 diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 5dc6ca7bf..962088246 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -340,7 +340,7 @@ def test_save_load_env_cnn(tmp_path, model_class): # clear file from os os.remove(tmp_path / "test_save.zip") - # Check we can load models saved with SB3 < 1.7.0 + # Check we can load A2C/PPO models saved with SB3 < 1.7.0 if model_class == A2C: del model.policy.pi_features_extractor model.save(tmp_path / "test_save") @@ -809,3 +809,15 @@ def test_save_load_net_arch_none(tmp_path): # None has been replaced by the default net arch assert model.policy.net_arch is not None os.remove(tmp_path / "ppo.zip") + + +def test_save_load_no_target_params(tmp_path): + # Check we can load DQN models saved with SB3 < 2.4.0 + model = DQN("MlpPolicy", "CartPole-v1", buffer_size=10000, learning_starts=4) + env = model.get_env() + # Include target net params + model.policy.optimizer = th.optim.Adam(model.policy.parameters(), lr=0.001) + model.save(tmp_path / "test_save") + with pytest.warns(UserWarning): + DQN.load(str(tmp_path / "test_save.zip"), env=env).learn(20) + os.remove(tmp_path / "test_save.zip")