Fix render_mode when loading VecNormalize #1671

Merged
merged 3 commits into from
Sep 12, 2023
6 changes: 3 additions & 3 deletions CONTRIBUTING.md
@@ -38,7 +38,7 @@ pip install -e .[docs,tests,extra]

## Codestyle

-We use [black codestyle](https://github.com/psf/black) (max line length of 127 characters) together with [isort](https://github.com/timothycrosley/isort) to sort the imports.
+We use [black codestyle](https://github.com/psf/black) (max line length of 127 characters) together with [ruff](https://github.com/astral-sh/ruff) (isort rules) to sort the imports.
Collaborator

Looking forward to checking out ruff :-)

Member Author

ruff is already used as a replacement for flake8 and it's pretty fast =)

For the documentation, we use the default line length of 88 characters per line.

**Please run `make format`** to reformat your code. You can check the codestyle using `make check-codestyle` and `make lint`.
@@ -63,7 +63,7 @@ def my_function(arg1: type1, arg2: type2) -> returntype:

Before proposing a PR, please open an issue, where the feature will be discussed. This prevent from duplicated PR to be proposed and also ease the code review process.

-Each PR need to be reviewed and accepted by at least one of the maintainers (@hill-a, @araffin, @ernestum, @AdamGleave or @Miffyli).
+Each PR need to be reviewed and accepted by at least one of the maintainers (@hill-a, @araffin, @ernestum, @AdamGleave, @Miffyli or @qgallouedec).
A PR must pass the Continuous Integration tests to be merged with the master branch.


@@ -85,7 +85,7 @@ Type checking with `pytype` and `mypy`:
make type
```

-Codestyle check with `black`, `isort` and `ruff`:
+Codestyle check with `black`, and `ruff` (`isort` rules):

```
make check-codestyle
4 changes: 2 additions & 2 deletions Makefile
@@ -27,13 +27,13 @@ lint:

 format:
 	# Sort imports
-	isort ${LINT_PATHS}
+	ruff --select I ${LINT_PATHS} --fix
 	# Reformat using black
 	black ${LINT_PATHS}

 check-codestyle:
 	# Sort imports
-	isort --check ${LINT_PATHS}
+	ruff --select I ${LINT_PATHS}
 	# Reformat using black
 	black --check ${LINT_PATHS}

11 changes: 9 additions & 2 deletions docs/misc/changelog.rst
@@ -3,11 +3,12 @@
Changelog
==========

-Release 2.2.0a1 (WIP)
+Release 2.2.0a2 (WIP)
--------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^
+- Switched to ``ruff`` for sorting imports (isort is no longer needed), black and ruff version now require a minimum version

New Features:
^^^^^^^^^^^^^
@@ -18,14 +19,19 @@ New Features:
`RL Zoo`_
^^^^^^^^^

+`SBX`_
+^^^^^^^^^
+- Added ``DDPG`` and ``TD3``
+
Bug Fixes:
^^^^^^^^^^
- Prevents using squash_output and not use_sde in ActorCritcPolicy (@PatrickHelm)
- Performs unscaling of actions in collect_rollout in OnPolicyAlgorithm (@PatrickHelm)
- Moves VectorizedActionNoise into ``_setup_learn()`` in OffPolicyAlgorithm (@PatrickHelm)
- Prevents out of bound error on Windows if no seed is passed (@PatrickHelm)
- Calls ``callback.update_locals()`` before ``callback.on_rollout_end()`` in OnPolicyAlgorithm (@PatrickHelm)
-- Fixes replay buffer device after loading in OffPolicyAlgorithm (@PatrickHelm)
+- Fixed replay buffer device after loading in OffPolicyAlgorithm (@PatrickHelm)
+- Fixed ``render_mode`` which was not properly loaded when using ``VecNormalize.load()``


Deprecations:
@@ -1424,6 +1430,7 @@ and `Quentin Gallouédec`_ (aka @qgallouedec).

.. _SB3-Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
.. _RL Zoo: https://github.com/DLR-RM/rl-baselines3-zoo
+.. _SBX: https://github.com/araffin/sbx

Contributors:
-------------
5 changes: 0 additions & 5 deletions pyproject.toml
@@ -24,11 +24,6 @@ max-complexity = 15
[tool.black]
line-length = 127

-[tool.isort]
-profile = "black"
-line_length = 127
-src_paths = ["stable_baselines3"]
-
[tool.pytype]
inputs = ["stable_baselines3"]
disable = ["pyi-error"]
8 changes: 3 additions & 5 deletions setup.py
@@ -120,12 +120,10 @@
             # Type check
             "pytype",
             "mypy",
-            # Lint code (flake8 replacement)
-            "ruff",
-            # Sort imports
-            "isort>=5.0",
+            # Lint code and sort imports (flake8 and isort replacement)
+            "ruff>=0.0.288",
             # Reformat
-            "black",
+            "black>=23.9.1,<24",
Collaborator

Btw black~=23.9 should do the same here

Member Author

yes it is, I do find >=23.9.1,<24 more explicit though (the thing is ~= is not consistent among package managers/languages... :/ see https://jubianchi.github.io/semver-check/#/~23.9/23.10 for instance).

Collaborator

Oh I did not know this! Is this also the case in the python world?

Member Author

so far, I think (and hope) it is fine (I'm mostly sticking to pip/mamba so I cannot say for the rest)
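
For the Python side of the question above, here is a minimal sketch of how the two specifiers compare under the compatible-release rule of PEP 440, where `~=23.9` is equivalent to `>=23.9, ==23.*`. The helpers `parse`, `pad`, `matches_compatible` and `matches_range` are hypothetical and only handle plain numeric versions (no pre-, post- or dev-releases); real tools rely on a full PEP 440 implementation.

```python
from typing import List, Tuple

Version = Tuple[int, ...]


def parse(version: str) -> Version:
    """Turn '23.9.1' into (23, 9, 1). Numeric release segments only."""
    return tuple(int(part) for part in version.split("."))


def pad(*versions: Version) -> List[Version]:
    """Right-pad with zeros so that 23.9 and 23.9.0 compare as equal."""
    width = max(len(v) for v in versions)
    return [v + (0,) * (width - len(v)) for v in versions]


def matches_compatible(version: str, spec: str) -> bool:
    """Simplified `~=spec`: at least `spec`, with all but its last segment pinned."""
    v, s = parse(version), parse(spec)
    prefix = s[:-1]
    v_padded, s_padded = pad(v, s)
    return v_padded >= s_padded and v_padded[: len(prefix)] == prefix


def matches_range(version: str, lower: str, upper: str) -> bool:
    """Simplified `>=lower,<upper`."""
    v, lo, hi = pad(parse(version), parse(lower), parse(upper))
    return lo <= v < hi


# Illustrative version strings only, chosen to probe the boundaries.
for candidate in ["23.9.0", "23.9.1", "23.12.0", "24.1.0"]:
    print(
        candidate,
        matches_compatible(candidate, "23.9"),
        matches_range(candidate, "23.9.1", "24"),
    )
```

Under these assumptions the two forms agree on every candidate except `23.9.0`, which `~=23.9` accepts and `>=23.9.1,<24` rejects.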

         ],
         "docs": [
             "sphinx>=5.3,<7.0",
1 change: 1 addition & 0 deletions stable_baselines3/common/vec_env/vec_normalize.py
@@ -163,6 +163,7 @@ def set_venv(self, venv: VecEnv) -> None:
         self.venv = venv
         self.num_envs = venv.num_envs
         self.class_attributes = dict(inspect.getmembers(self.__class__))
+        self.render_mode = venv.render_mode

         # Check that the observation_space shape match
         utils.check_shape_equal(self.observation_space, venv.observation_space)
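
To see why `set_venv` has to refresh `render_mode`, here is a toy sketch. It assumes the wrapper drops its wrapped env when serialized and has a new one attached on load, so any attribute cached from the old env goes stale. `ToyVecEnv`, `ToyNormalize`, `roundtrip` and the `sync_render_mode` flag are hypothetical stand-ins, not the Stable-Baselines3 classes.

```python
from typing import Optional


class ToyVecEnv:
    """Hypothetical stand-in for a vectorized env."""

    def __init__(self, render_mode: Optional[str] = None) -> None:
        self.render_mode = render_mode


class ToyNormalize:
    """Hypothetical wrapper that caches `render_mode` from the env it wraps."""

    def __init__(self, venv: ToyVecEnv) -> None:
        self.venv: Optional[ToyVecEnv] = venv
        self.render_mode = venv.render_mode

    def __getstate__(self) -> dict:
        # Assumption: the wrapped env is not saved, only the wrapper's own state.
        state = self.__dict__.copy()
        del state["venv"]
        return state

    def __setstate__(self, state: dict) -> None:
        self.__dict__.update(state)
        self.venv = None

    def set_venv(self, venv: ToyVecEnv, sync_render_mode: bool) -> None:
        self.venv = venv
        if sync_render_mode:
            # Mirrors the one-line change in the diff above.
            self.render_mode = venv.render_mode


def roundtrip(wrapper: ToyNormalize) -> ToyNormalize:
    """Mimic a save/load cycle without touching the disk."""
    restored = ToyNormalize.__new__(ToyNormalize)
    restored.__setstate__(wrapper.__getstate__())
    return restored


saved = ToyNormalize(ToyVecEnv(render_mode=None))
new_env = ToyVecEnv(render_mode="rgb_array")

stale = roundtrip(saved)
stale.set_venv(new_env, sync_render_mode=False)

fixed = roundtrip(saved)
fixed.set_venv(new_env, sync_render_mode=True)

print(stale.render_mode, fixed.render_mode)  # None rgb_array
```

Without the sync, the restored wrapper keeps reporting the render mode it was saved with, whatever env it is attached to afterwards.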
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
@@ -1 +1 @@
-2.2.0a1
+2.2.0a2
23 changes: 20 additions & 3 deletions tests/test_vec_normalize.py
@@ -123,6 +123,10 @@ def make_env():
     return Monitor(gym.make(ENV_ID))


+def make_env_render():
+    return Monitor(gym.make(ENV_ID, render_mode="rgb_array"))
+
+
 def make_dict_env():
     return Monitor(DummyDictEnv())

@@ -257,14 +261,17 @@ def test_obs_rms_vec_normalize():
     assert np.allclose(env.ret_rms.mean, 5.688, atol=1e-3)


-@pytest.mark.parametrize("make_env", [make_env, make_dict_env, make_image_env])
-def test_vec_env(tmp_path, make_env):
+@pytest.mark.parametrize("make_gym_env", [make_env, make_dict_env, make_image_env])
+def test_vec_env(tmp_path, make_gym_env):
     """Test VecNormalize Object"""
     clip_obs = 0.5
     clip_reward = 5.0

-    orig_venv = DummyVecEnv([make_env])
+    orig_venv = DummyVecEnv([make_gym_env])
     norm_venv = VecNormalize(orig_venv, norm_obs=True, norm_reward=True, clip_obs=clip_obs, clip_reward=clip_reward)
+    assert orig_venv.render_mode is None
+    assert norm_venv.render_mode is None
+
     _, done = norm_venv.reset(), [False]
     while not done[0]:
         actions = [norm_venv.action_space.sample()]
@@ -278,9 +285,19 @@ def test_vec_env(tmp_path, make_env):

     path = tmp_path / "vec_normalize"
     norm_venv.save(path)
+    assert orig_venv.render_mode is None
     deserialized = VecNormalize.load(path, venv=orig_venv)
+    assert deserialized.render_mode is None
     check_vec_norm_equal(norm_venv, deserialized)
+
+    # Check that render mode is properly updated
+    vec_env = DummyVecEnv([make_env_render])
+    assert vec_env.render_mode == "rgb_array"
+    # Test that loading and wrapping keep the correct render mode
+    if make_gym_env == make_env:
+        assert VecNormalize.load(path, venv=vec_env).render_mode == "rgb_array"
+    assert VecNormalize(vec_env).render_mode == "rgb_array"


 def test_get_original():
     venv = _make_warmstart_cartpole()
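
One plausible reason for renaming the parametrized argument from `make_env` to `make_gym_env` (the PR does not state it) is that the new `if make_gym_env == make_env:` check needs the module-level `make_env`, which a parameter of the same name would hide. A minimal sketch of that shadowing, with hypothetical `check_shadowed` and `check_renamed` helpers and toy factory functions:

```python
def make_env() -> str:
    return "plain"


def make_dict_env() -> str:
    return "dict"


def check_shadowed(make_env) -> bool:
    # The parameter hides the module-level function of the same name,
    # so this compares the argument with itself and is always True.
    return make_env == make_env


def check_renamed(make_gym_env) -> bool:
    # With a distinct parameter name, `make_env` refers to the module-level function.
    return make_gym_env == make_env


print([check_shadowed(f) for f in (make_env, make_dict_env)])  # [True, True]
print([check_renamed(f) for f in (make_env, make_dict_env)])  # [True, False]
```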