Skip to content

Commit

Permalink
Add test and update changelog
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin committed Sep 12, 2023
1 parent 5af53ce commit 0fe8704
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 6 deletions.
11 changes: 9 additions & 2 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
Changelog
==========

Release 2.2.0a1 (WIP)
Release 2.2.0a2 (WIP)
--------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^
- Switched to ``ruff`` for sorting imports (isort is no longer needed), black and ruff version now require a minimum version

New Features:
^^^^^^^^^^^^^
Expand All @@ -18,14 +19,19 @@ New Features:
`RL Zoo`_
^^^^^^^^^

`SBX`_
^^^^^^^^^
- Added ``DDPG`` and ``TD3``

Bug Fixes:
^^^^^^^^^^
- Prevents using squash_output and not use_sde in ActorCritcPolicy (@PatrickHelm)
- Performs unscaling of actions in collect_rollout in OnPolicyAlgorithm (@PatrickHelm)
- Moves VectorizedActionNoise into ``_setup_learn()`` in OffPolicyAlgorithm (@PatrickHelm)
- Prevents out of bound error on Windows if no seed is passed (@PatrickHelm)
- Calls ``callback.update_locals()`` before ``callback.on_rollout_end()`` in OnPolicyAlgorithm (@PatrickHelm)
- Fixes replay buffer device after loading in OffPolicyAlgorithm (@PatrickHelm)
- Fixed replay buffer device after loading in OffPolicyAlgorithm (@PatrickHelm)
- Fixed ``render_mode`` which was not properly loaded when using ``VecNormalize.load()``


Deprecations:
Expand Down Expand Up @@ -1424,6 +1430,7 @@ and `Quentin Gallouédec`_ (aka @qgallouedec).

.. _SB3-Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
.. _RL Zoo: https://github.com/DLR-RM/rl-baselines3-zoo
.. _SBX: https://github.com/araffin/sbx

Contributors:
-------------
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.2.0a1
2.2.0a2
23 changes: 20 additions & 3 deletions tests/test_vec_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ def make_env():
return Monitor(gym.make(ENV_ID))


def make_env_render():
return Monitor(gym.make(ENV_ID, render_mode="rgb_array"))


def make_dict_env():
return Monitor(DummyDictEnv())

Expand Down Expand Up @@ -257,14 +261,17 @@ def test_obs_rms_vec_normalize():
assert np.allclose(env.ret_rms.mean, 5.688, atol=1e-3)


@pytest.mark.parametrize("make_env", [make_env, make_dict_env, make_image_env])
def test_vec_env(tmp_path, make_env):
@pytest.mark.parametrize("make_gym_env", [make_env, make_dict_env, make_image_env])
def test_vec_env(tmp_path, make_gym_env):
"""Test VecNormalize Object"""
clip_obs = 0.5
clip_reward = 5.0

orig_venv = DummyVecEnv([make_env])
orig_venv = DummyVecEnv([make_gym_env])
norm_venv = VecNormalize(orig_venv, norm_obs=True, norm_reward=True, clip_obs=clip_obs, clip_reward=clip_reward)
assert orig_venv.render_mode is None
assert norm_venv.render_mode is None

_, done = norm_venv.reset(), [False]
while not done[0]:
actions = [norm_venv.action_space.sample()]
Expand All @@ -278,9 +285,19 @@ def test_vec_env(tmp_path, make_env):

path = tmp_path / "vec_normalize"
norm_venv.save(path)
assert orig_venv.render_mode is None
deserialized = VecNormalize.load(path, venv=orig_venv)
assert deserialized.render_mode is None
check_vec_norm_equal(norm_venv, deserialized)

# Check that render mode is properly updated
vec_env = DummyVecEnv([make_env_render])
assert vec_env.render_mode == "rgb_array"
# Test that loading and wrapping keep the correct render mode
if make_gym_env == make_env:
assert VecNormalize.load(path, venv=vec_env).render_mode == "rgb_array"
assert VecNormalize(vec_env).render_mode == "rgb_array"


def test_get_original():
venv = _make_warmstart_cartpole()
Expand Down

0 comments on commit 0fe8704

Please sign in to comment.