Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix type annotations of buffers #1700

Merged
merged 8 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Changelog
==========

Release 2.2.0a5 (WIP)
Release 2.2.0a6 (WIP)
--------------------------

Breaking Changes:
Expand Down Expand Up @@ -49,6 +49,9 @@ Others:
- Fixed ``stable_baselines3/common/vec_env/vec_video_recorder.py`` type hints
- Fixed ``stable_baselines3/common/save_util.py`` type hints
- Updated docker images to Ubuntu Jammy using micromamba 1.5
- Fixed ``stable_baselines3/common/buffers.py`` type hints
- Fixed ``stable_baselines3/her/her_replay_buffer.py`` type hints
- Buffers do not call an additional ``.copy()`` when storing new transitions

Documentation:
^^^^^^^^^^^^^^
Expand Down
14 changes: 11 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,27 @@ line-length = 127
[tool.pytype]
inputs = ["stable_baselines3"]
disable = ["pyi-error"]
# Checked with mypy
exclude = [
"stable_baselines3/common/buffers.py",
"stable_baselines3/common/base_class.py",
"stable_baselines3/common/callbacks.py",
"stable_baselines3/common/on_policy_algorithm.py",
"stable_baselines3/common/vec_env/stacked_observations.py",
"stable_baselines3/common/vec_env/subproc_vec_env.py",
"stable_baselines3/common/vec_env/patch_gym.py"
]

[tool.mypy]
ignore_missing_imports = true
follow_imports = "silent"
show_error_codes = true
exclude = """(?x)(
stable_baselines3/common/buffers.py$
| stable_baselines3/common/distributions.py$
stable_baselines3/common/distributions.py$
| stable_baselines3/common/off_policy_algorithm.py$
| stable_baselines3/common/policies.py$
| stable_baselines3/common/vec_env/__init__.py$
| stable_baselines3/common/vec_env/vec_normalize.py$
| stable_baselines3/her/her_replay_buffer.py$
| tests/test_logger.py$
| tests/test_train_eval_mode.py$
)"""
Expand Down
8 changes: 2 additions & 6 deletions stable_baselines3/common/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,9 +420,7 @@ def _setup_learn(
# Avoid resetting the environment when calling ``.learn()`` consecutive times
if reset_num_timesteps or self._last_obs is None:
assert self.env is not None
# pytype: disable=annotation-type-mismatch
self._last_obs = self.env.reset() # type: ignore[assignment]
# pytype: enable=annotation-type-mismatch
self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool)
# Retrieve unnormalized observation for saving into the buffer
if self._vec_normalize_env is not None:
Expand Down Expand Up @@ -707,7 +705,7 @@ def load( # noqa: C901

# Gym -> Gymnasium space conversion
for key in {"observation_space", "action_space"}:
data[key] = _convert_space(data[key]) # pytype: disable=unsupported-operands
data[key] = _convert_space(data[key])

if env is not None:
# Wrap first if needed
Expand All @@ -726,14 +724,12 @@ def load( # noqa: C901
if "env" in data:
env = data["env"]

# pytype: disable=not-instantiable,wrong-keyword-args
model = cls(
policy=data["policy_class"],
env=env,
device=device,
_init_setup_model=False, # type: ignore[call-arg]
)
# pytype: enable=not-instantiable,wrong-keyword-args

# load parameters
model.__dict__.update(data)
Expand Down Expand Up @@ -776,7 +772,7 @@ def load( # noqa: C901
# Sample gSDE exploration matrix, so it uses the right device
# see issue #44
if model.use_sde:
model.policy.reset_noise() # type: ignore[operator] # pytype: disable=attribute-error
model.policy.reset_noise() # type: ignore[operator]
return model

def get_parameters(self) -> Dict[str, Dict]:
Expand Down
Loading