Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix loading of optimizer with older DQN models #1978

Merged
merged 1 commit into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,16 @@
Changelog
==========

Release 2.4.0a6 (WIP)
Release 2.4.0a7 (WIP)
--------------------------

.. note::

DQN (and QR-DQN) models saved with SB3 < 2.4.0 will show a warning about
truncation of optimizer state when loaded with SB3 >= 2.4.0.
To suppress the warning, simply save the model again.
You can find more info in `PR #1963 <https://github.com/DLR-RM/stable-baselines3/pull/1963>`_

Breaking Changes:
^^^^^^^^^^^^^^^^^

Expand All @@ -28,9 +35,11 @@ Bug Fixes:

`RL Zoo`_
^^^^^^^^^
- Updated defaults hyperparameters for TQC/SAC for Swimmer-v4 (decrease gamma for more consistent results)

`SBX`_ (SB3 + Jax)
^^^^^^^^^^^^^^^^^^
- Added CNN support for DQN

Deprecations:
^^^^^^^^^^^^^
Expand Down
27 changes: 25 additions & 2 deletions stable_baselines3/common/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,13 +742,13 @@ def load( # noqa: C901
# put state_dicts back in place
model.set_parameters(params, exact_match=True, device=device)
except RuntimeError as e:
# Patch to load Policy saved using SB3 < 1.7.0
# Patch to load policies saved using SB3 < 1.7.0
# the error is probably due to old policy being loaded
# See https://github.com/DLR-RM/stable-baselines3/issues/1233
if "pi_features_extractor" in str(e) and "Missing key(s) in state_dict" in str(e):
model.set_parameters(params, exact_match=False, device=device)
warnings.warn(
"You are probably loading a model saved with SB3 < 1.7.0, "
"You are probably loading a A2C/PPO model saved with SB3 < 1.7.0, "
"we deactivated exact_match so you can save the model "
"again to avoid issues in the future "
"(see https://github.com/DLR-RM/stable-baselines3/issues/1233 for more info). "
Expand All @@ -757,6 +757,29 @@ def load( # noqa: C901
)
else:
raise e
except ValueError as e:
# Patch to load DQN policies saved using SB3 < 2.4.0
# The target network params are no longer in the optimizer
# See https://github.com/DLR-RM/stable-baselines3/pull/1963
saved_optim_params = params["policy.optimizer"]["param_groups"][0]["params"] # type: ignore[index]
n_params_saved = len(saved_optim_params)
n_params = len(model.policy.optimizer.param_groups[0]["params"])
if n_params_saved == 2 * n_params:
# Truncate to include only online network params
params["policy.optimizer"]["param_groups"][0]["params"] = saved_optim_params[:n_params] # type: ignore[index]

model.set_parameters(params, exact_match=True, device=device)
warnings.warn(
"You are probably loading a DQN model saved with SB3 < 2.4.0, "
"we truncated the optimizer state so you can save the model "
"again to avoid issues in the future "
"(see https://github.com/DLR-RM/stable-baselines3/pull/1963 for more info). "
f"Original error: {e} \n"
"Note: the model should still work fine, this only a warning."
)
else:
raise e

# put other pytorch variables back in place
if pytorch_variables is not None:
for name in pytorch_variables:
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.4.0a6
2.4.0a7
14 changes: 13 additions & 1 deletion tests/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def test_save_load_env_cnn(tmp_path, model_class):
# clear file from os
os.remove(tmp_path / "test_save.zip")

# Check we can load models saved with SB3 < 1.7.0
# Check we can load A2C/PPO models saved with SB3 < 1.7.0
if model_class == A2C:
del model.policy.pi_features_extractor
model.save(tmp_path / "test_save")
Expand Down Expand Up @@ -809,3 +809,15 @@ def test_save_load_net_arch_none(tmp_path):
# None has been replaced by the default net arch
assert model.policy.net_arch is not None
os.remove(tmp_path / "ppo.zip")


def test_save_load_no_target_params(tmp_path):
# Check we can load DQN models saved with SB3 < 2.4.0
model = DQN("MlpPolicy", "CartPole-v1", buffer_size=10000, learning_starts=4)
env = model.get_env()
# Include target net params
model.policy.optimizer = th.optim.Adam(model.policy.parameters(), lr=0.001)
model.save(tmp_path / "test_save")
with pytest.warns(UserWarning):
DQN.load(str(tmp_path / "test_save.zip"), env=env).learn(20)
os.remove(tmp_path / "test_save.zip")
Loading