From bb13aefe4da3d1da5556f79527be8f4e79ad756e Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 26 Sep 2023 17:44:49 +0200 Subject: [PATCH 1/6] Fix type annotation and replay buffer --- stable_baselines3/common/buffers.py | 99 +++++++++++-------- stable_baselines3/common/preprocessing.py | 2 +- .../common/sb2_compat/rmsprop_tf_like.py | 2 +- 3 files changed, 61 insertions(+), 42 deletions(-) diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index 576e10a8b..2e8f034c6 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -1,6 +1,6 @@ import warnings from abc import ABC, abstractmethod -from typing import Any, Dict, Generator, List, Optional, Union +from typing import Any, Dict, Generator, List, Optional, Tuple, Union import numpy as np import torch as th @@ -35,6 +35,9 @@ class BaseBuffer(ABC): :param n_envs: Number of parallel environments """ + observation_space: spaces.Space + obs_shape: Tuple[int, ...] + def __init__( self, buffer_size: int, @@ -47,7 +50,7 @@ def __init__( self.buffer_size = buffer_size self.observation_space = observation_space self.action_space = action_space - self.obs_shape = get_obs_shape(observation_space) + self.obs_shape = get_obs_shape(observation_space) # type: ignore[assignment] self.action_dim = get_action_dim(action_space) self.pos = 0 @@ -171,6 +174,13 @@ class ReplayBuffer(BaseBuffer): https://github.com/DLR-RM/stable-baselines3/issues/284 """ + observations: np.ndarray + next_observations: np.ndarray + actions: np.ndarray + rewards: np.ndarray + dones: np.ndarray + timeouts: np.ndarray + def __init__( self, buffer_size: int, @@ -201,10 +211,8 @@ def __init__( self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=observation_space.dtype) - if optimize_memory_usage: - # `observations` contains also the next observation - self.next_observations = None - else: + if not optimize_memory_usage: + # When optimizing memory, `observations` contains also the next observation self.next_observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=observation_space.dtype) self.actions = np.zeros( @@ -219,7 +227,9 @@ def __init__( self.timeouts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) if psutil is not None: - total_memory_usage = self.observations.nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes + total_memory_usage: float = ( + self.observations.nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes + ) if self.next_observations is not None: total_memory_usage += self.next_observations.nbytes @@ -252,16 +262,16 @@ def add( action = action.reshape((self.n_envs, self.action_dim)) # Copy to avoid modification by reference - self.observations[self.pos] = np.array(obs).copy() + self.observations[self.pos] = np.array(obs) if self.optimize_memory_usage: - self.observations[(self.pos + 1) % self.buffer_size] = np.array(next_obs).copy() + self.observations[(self.pos + 1) % self.buffer_size] = np.array(next_obs) else: - self.next_observations[self.pos] = np.array(next_obs).copy() + self.next_observations[self.pos] = np.array(next_obs) - self.actions[self.pos] = np.array(action).copy() - self.rewards[self.pos] = np.array(reward).copy() - self.dones[self.pos] = np.array(done).copy() + self.actions[self.pos] = np.array(action) + self.rewards[self.pos] = np.array(reward) + self.dones[self.pos] = np.array(done) if self.handle_timeout_termination: self.timeouts[self.pos] = 
np.array([info.get("TimeLimit.truncated", False) for info in infos]) @@ -457,10 +467,10 @@ def add( # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392 action = action.reshape((self.n_envs, self.action_dim)) - self.observations[self.pos] = np.array(obs).copy() - self.actions[self.pos] = np.array(action).copy() - self.rewards[self.pos] = np.array(reward).copy() - self.episode_starts[self.pos] = np.array(episode_start).copy() + self.observations[self.pos] = np.array(obs) + self.actions[self.pos] = np.array(action) + self.rewards[self.pos] = np.array(reward) + self.episode_starts[self.pos] = np.array(episode_start) self.values[self.pos] = value.clone().cpu().numpy().flatten() self.log_probs[self.pos] = log_prob.clone().cpu().numpy() self.pos += 1 @@ -527,10 +537,15 @@ class DictReplayBuffer(ReplayBuffer): https://github.com/DLR-RM/stable-baselines3/issues/284 """ + observation_space: spaces.Dict + obs_shape: Dict[str, Tuple[int, ...]] # type: ignore[assignment] + observations: Dict[str, np.ndarray] # type: ignore[assignment] + next_observations: Dict[str, np.ndarray] # type: ignore[assignment] + def __init__( self, buffer_size: int, - observation_space: spaces.Space, + observation_space: spaces.Dict, action_space: spaces.Space, device: Union[th.device, str] = "auto", n_envs: int = 1, @@ -576,8 +591,8 @@ def __init__( for _, obs in self.observations.items(): obs_nbytes += obs.nbytes - total_memory_usage = obs_nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes - if self.next_observations is not None: + total_memory_usage: float = obs_nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes + if not optimize_memory_usage: next_obs_nbytes = 0 for _, obs in self.observations.items(): next_obs_nbytes += obs.nbytes @@ -592,7 +607,7 @@ def __init__( f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB" ) - def add( + def add( # type: ignore[override] self, obs: Dict[str, np.ndarray], next_obs: Dict[str, np.ndarray], @@ -612,14 +627,14 @@ def add( for key in self.next_observations.keys(): if isinstance(self.observation_space.spaces[key], spaces.Discrete): next_obs[key] = next_obs[key].reshape((self.n_envs,) + self.obs_shape[key]) - self.next_observations[key][self.pos] = np.array(next_obs[key]).copy() + self.next_observations[key][self.pos] = np.array(next_obs[key]) # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392 action = action.reshape((self.n_envs, self.action_dim)) - self.actions[self.pos] = np.array(action).copy() - self.rewards[self.pos] = np.array(reward).copy() - self.dones[self.pos] = np.array(done).copy() + self.actions[self.pos] = np.array(action) + self.rewards[self.pos] = np.array(reward) + self.dones[self.pos] = np.array(done) if self.handle_timeout_termination: self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos]) @@ -629,11 +644,11 @@ def add( self.full = True self.pos = 0 - def sample( + def sample( # type: ignore[override] self, batch_size: int, env: Optional[VecNormalize] = None, - ) -> DictReplayBufferSamples: # type: ignore[signature-mismatch] #FIXME: + ) -> DictReplayBufferSamples: # type: ignore[signature-mismatch] """ Sample elements from the replay buffer. 
@@ -644,7 +659,7 @@ def sample( """ return super(ReplayBuffer, self).sample(batch_size=batch_size, env=env) - def _get_samples( + def _get_samples( # type: ignore[override] self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None, @@ -658,6 +673,8 @@ def _get_samples( {key: obs[batch_inds, env_indices, :] for key, obs in self.next_observations.items()}, env ) + assert isinstance(obs_, dict) + assert isinstance(next_obs_, dict) # Convert to torch tensor observations = {key: self.to_torch(obs) for key, obs in obs_.items()} next_observations = {key: self.to_torch(obs) for key, obs in next_obs_.items()} @@ -700,12 +717,14 @@ class DictRolloutBuffer(RolloutBuffer): :param n_envs: Number of parallel environments """ - observations: Dict[str, np.ndarray] + observation_space: spaces.Dict + obs_shape: Dict[str, Tuple[int, ...]] # type: ignore[assignment] + observations: Dict[str, np.ndarray] # type: ignore[assignment] def __init__( self, buffer_size: int, - observation_space: spaces.Space, + observation_space: spaces.Dict, action_space: spaces.Space, device: Union[th.device, str] = "auto", gae_lambda: float = 1, @@ -723,7 +742,7 @@ def __init__( self.reset() def reset(self) -> None: - assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only" + # assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only" self.observations = {} for key, obs_input_shape in self.obs_shape.items(): self.observations[key] = np.zeros((self.buffer_size, self.n_envs, *obs_input_shape), dtype=np.float32) @@ -737,7 +756,7 @@ def reset(self) -> None: self.generator_ready = False super(RolloutBuffer, self).reset() - def add( + def add( # type: ignore[override] self, obs: Dict[str, np.ndarray], action: np.ndarray, @@ -761,7 +780,7 @@ def add( log_prob = log_prob.reshape(-1, 1) for key in self.observations.keys(): - obs_ = np.array(obs[key]).copy() + obs_ = np.array(obs[key]) # Reshape needed when using multiple envs with discrete observations # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1) if isinstance(self.observation_space.spaces[key], spaces.Discrete): @@ -771,19 +790,19 @@ def add( # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392 action = action.reshape((self.n_envs, self.action_dim)) - self.actions[self.pos] = np.array(action).copy() - self.rewards[self.pos] = np.array(reward).copy() - self.episode_starts[self.pos] = np.array(episode_start).copy() + self.actions[self.pos] = np.array(action) + self.rewards[self.pos] = np.array(reward) + self.episode_starts[self.pos] = np.array(episode_start) self.values[self.pos] = value.clone().cpu().numpy().flatten() self.log_probs[self.pos] = log_prob.clone().cpu().numpy() self.pos += 1 if self.pos == self.buffer_size: self.full = True - def get( + def get( # type: ignore[override] self, batch_size: Optional[int] = None, - ) -> Generator[DictRolloutBufferSamples, None, None]: # type: ignore[signature-mismatch] #FIXME + ) -> Generator[DictRolloutBufferSamples, None, None]: # type: ignore[signature-mismatch] assert self.full, "" indices = np.random.permutation(self.buffer_size * self.n_envs) # Prepare the data @@ -806,11 +825,11 @@ def get( yield self._get_samples(indices[start_idx : start_idx + batch_size]) start_idx += batch_size - def _get_samples( + def _get_samples( # type: ignore[override] self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None, - ) -> DictRolloutBufferSamples: # type: ignore[signature-mismatch] #FIXME + ) -> 
DictRolloutBufferSamples: # type: ignore[signature-mismatch] return DictRolloutBufferSamples( observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()}, actions=self.to_torch(self.actions[batch_inds]), diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py index bc0959480..a2d0e59c1 100644 --- a/stable_baselines3/common/preprocessing.py +++ b/stable_baselines3/common/preprocessing.py @@ -112,7 +112,7 @@ def preprocess_obs( elif isinstance(observation_space, spaces.Discrete): # One hot encoding and convert to float to avoid errors - return F.one_hot(obs.long(), num_classes=observation_space.n).float() + return F.one_hot(obs.long(), num_classes=int(observation_space.n)).float() elif isinstance(observation_space, spaces.MultiDiscrete): # Tensor concatenation of one hot encodings of each Categorical sub-space diff --git a/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py b/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py index 9d74798e0..25f0a6f96 100644 --- a/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py +++ b/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py @@ -74,7 +74,7 @@ def __setstate__(self, state: Dict[str, Any]) -> None: group.setdefault("centered", False) @torch.no_grad() - def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: + def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: # type: ignore[override] """Performs a single optimization step. :param closure: A closure that reevaluates the model From 735741e6202594f19514de1241090b766a8ba0ed Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 27 Sep 2023 11:16:41 +0200 Subject: [PATCH 2/6] Exclude pytype check --- pyproject.toml | 5 +++-- stable_baselines3/common/buffers.py | 14 +++++++------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7e5d2b629..d8afd5b1a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,14 +27,15 @@ line-length = 127 [tool.pytype] inputs = ["stable_baselines3"] disable = ["pyi-error"] +# Checked with mypy +exclude = ["stable_baselines3/common/buffers.py"] [tool.mypy] ignore_missing_imports = true follow_imports = "silent" show_error_codes = true exclude = """(?x)( - stable_baselines3/common/buffers.py$ - | stable_baselines3/common/distributions.py$ + stable_baselines3/common/distributions.py$ | stable_baselines3/common/off_policy_algorithm.py$ | stable_baselines3/common/policies.py$ | stable_baselines3/common/vec_env/__init__.py$ diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index 2e8f034c6..944a89dc8 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -508,7 +508,7 @@ def _get_samples( self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None, - ) -> RolloutBufferSamples: # type: ignore[signature-mismatch] #FIXME + ) -> RolloutBufferSamples: #FIXME data = ( self.observations[batch_inds], self.actions[batch_inds], @@ -615,7 +615,7 @@ def add( # type: ignore[override] reward: np.ndarray, done: np.ndarray, infos: List[Dict[str, Any]], - ) -> None: # pytype: disable=signature-mismatch + ) -> None: # Copy to avoid modification by reference for key in self.observations.keys(): # Reshape needed when using multiple envs with discrete observations @@ -648,7 +648,7 @@ def sample( # type: ignore[override] self, batch_size: int, env: Optional[VecNormalize] = None, - ) -> DictReplayBufferSamples: # type: 
ignore[signature-mismatch]
+    ) -> DictReplayBufferSamples:
         """
         Sample elements from the replay buffer.
 
@@ -663,7 +663,7 @@ def _get_samples( # type: ignore[override]
         self,
         batch_inds: np.ndarray,
         env: Optional[VecNormalize] = None,
-    ) -> DictReplayBufferSamples: # type: ignore[signature-mismatch] #FIXME:
+    ) -> DictReplayBufferSamples:
         # Sample randomly the env idx
         env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),))
 
@@ -764,7 +764,7 @@ def add( # type: ignore[override]
         episode_start: np.ndarray,
         value: th.Tensor,
         log_prob: th.Tensor,
-    ) -> None: # pytype: disable=signature-mismatch
+    ) -> None:
         """
         :param obs: Observation
         :param action: Action
@@ -802,7 +802,7 @@ def add( # type: ignore[override]
     def get( # type: ignore[override]
         self,
         batch_size: Optional[int] = None,
-    ) -> Generator[DictRolloutBufferSamples, None, None]: # type: ignore[signature-mismatch]
+    ) -> Generator[DictRolloutBufferSamples, None, None]:
         assert self.full, ""
         indices = np.random.permutation(self.buffer_size * self.n_envs)
         # Prepare the data
@@ -829,7 +829,7 @@ def _get_samples( # type: ignore[override]
         self,
         batch_inds: np.ndarray,
         env: Optional[VecNormalize] = None,
-    ) -> DictRolloutBufferSamples: # type: ignore[signature-mismatch]
+    ) -> DictRolloutBufferSamples:
         return DictRolloutBufferSamples(
             observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
             actions=self.to_torch(self.actions[batch_inds]),

From 15dd3632a897272d3fda0de4b6ac4a855851bbdc Mon Sep 17 00:00:00 2001
From: Antonin Raffin
Date: Wed, 27 Sep 2023 11:36:24 +0200
Subject: [PATCH 3/6] Remove some pytype-specific annotations and update changelog

---
 docs/misc/changelog.rst | 4 +++-
 pyproject.toml | 10 +++++++++-
 stable_baselines3/common/base_class.py | 8 ++------
 stable_baselines3/common/buffers.py | 2 +-
 stable_baselines3/common/callbacks.py | 4 ++--
 stable_baselines3/common/on_policy_algorithm.py | 4 +---
 stable_baselines3/common/vec_env/patch_gym.py | 6 +++---
 .../common/vec_env/stacked_observations.py | 9 ++-------
 stable_baselines3/common/vec_env/subproc_vec_env.py | 2 --
 stable_baselines3/version.txt | 2 +-
 10 files changed, 24 insertions(+), 27 deletions(-)

diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index 393291873..c5d94d130 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -3,7 +3,7 @@
 Changelog
 ==========
 
-Release 2.2.0a5 (WIP)
+Release 2.2.0a6 (WIP)
 --------------------------
 
 Breaking Changes:
@@ -48,6 +48,8 @@ Others:
 - Fixed ``stable_baselines3/common/vec_env/vec_video_recorder.py`` type hints
 - Fixed ``stable_baselines3/common/save_util.py`` type hints
 - Updated docker images to Ubuntu Jammy using micromamba 1.5
+- Fixed ``stable_baselines3/common/buffers.py`` type hints
+- Buffers do not call an additional ``.copy()`` when storing new transitions
 
 Documentation:
 ^^^^^^^^^^^^^^
diff --git a/pyproject.toml b/pyproject.toml
index d8afd5b1a..5616b5d84 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,7 +28,15 @@
 inputs = ["stable_baselines3"]
 disable = ["pyi-error"]
 # Checked with mypy
-exclude = ["stable_baselines3/common/buffers.py"]
+exclude = [
+    "stable_baselines3/common/buffers.py",
+    "stable_baselines3/common/base_class.py",
+    "stable_baselines3/common/callbacks.py",
+    "stable_baselines3/common/on_policy_algorithm.py",
+    "stable_baselines3/common/vec_env/stacked_observations.py",
+    "stable_baselines3/common/vec_env/subproc_vec_env.py",
+    "stable_baselines3/common/vec_env/patch_gym.py"
+]
 
 [tool.mypy]
ignore_missing_imports = true diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 9f587063f..5e8759990 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -420,9 +420,7 @@ def _setup_learn( # Avoid resetting the environment when calling ``.learn()`` consecutive times if reset_num_timesteps or self._last_obs is None: assert self.env is not None - # pytype: disable=annotation-type-mismatch self._last_obs = self.env.reset() # type: ignore[assignment] - # pytype: enable=annotation-type-mismatch self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool) # Retrieve unnormalized observation for saving into the buffer if self._vec_normalize_env is not None: @@ -707,7 +705,7 @@ def load( # noqa: C901 # Gym -> Gymnasium space conversion for key in {"observation_space", "action_space"}: - data[key] = _convert_space(data[key]) # pytype: disable=unsupported-operands + data[key] = _convert_space(data[key]) if env is not None: # Wrap first if needed @@ -726,14 +724,12 @@ def load( # noqa: C901 if "env" in data: env = data["env"] - # pytype: disable=not-instantiable,wrong-keyword-args model = cls( policy=data["policy_class"], env=env, device=device, _init_setup_model=False, # type: ignore[call-arg] ) - # pytype: enable=not-instantiable,wrong-keyword-args # load parameters model.__dict__.update(data) @@ -776,7 +772,7 @@ def load( # noqa: C901 # Sample gSDE exploration matrix, so it uses the right device # see issue #44 if model.use_sde: - model.policy.reset_noise() # type: ignore[operator] # pytype: disable=attribute-error + model.policy.reset_noise() # type: ignore[operator] return model def get_parameters(self) -> Dict[str, Dict]: diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index 944a89dc8..263d8e100 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -508,7 +508,7 @@ def _get_samples( self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None, - ) -> RolloutBufferSamples: #FIXME + ) -> RolloutBufferSamples: data = ( self.observations[batch_inds], self.actions[batch_inds], diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py index f16b57976..54f1b97e5 100644 --- a/stable_baselines3/common/callbacks.py +++ b/stable_baselines3/common/callbacks.py @@ -19,7 +19,7 @@ # if the progress bar is used tqdm = None -from stable_baselines3.common import base_class # pytype: disable=pyi-error +from stable_baselines3.common import base_class from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization @@ -680,7 +680,7 @@ class ProgressBarCallback(BaseCallback): using tqdm and rich packages. 
""" - pbar: tqdm # pytype: disable=invalid-annotation + pbar: tqdm def __init__(self) -> None: super().__init__() diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index a3f47b2b2..1e0f9e6c9 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -112,18 +112,16 @@ def _setup_model(self) -> None: self.rollout_buffer = buffer_cls( self.n_steps, - self.observation_space, + self.observation_space, # type: ignore[arg-type] self.action_space, device=self.device, gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs, ) - # pytype:disable=not-instantiable self.policy = self.policy_class( # type: ignore[assignment] self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, **self.policy_kwargs ) - # pytype:enable=not-instantiable self.policy = self.policy.to(self.device) def collect_rollouts( diff --git a/stable_baselines3/common/vec_env/patch_gym.py b/stable_baselines3/common/vec_env/patch_gym.py index 7b1934bee..2da76a9b2 100644 --- a/stable_baselines3/common/vec_env/patch_gym.py +++ b/stable_baselines3/common/vec_env/patch_gym.py @@ -5,7 +5,7 @@ import gymnasium try: - import gym # pytype: disable=import-error + import gym gym_installed = True except ImportError: @@ -37,7 +37,7 @@ def _patch_env(env: Union["gym.Env", gymnasium.Env]) -> gymnasium.Env: # pragma ) try: - import shimmy # pytype: disable=import-error + import shimmy except ImportError as e: raise ImportError( "Missing shimmy installation. You provided an OpenAI Gym environment. " @@ -83,7 +83,7 @@ def _convert_space(space: Union["gym.Space", gymnasium.Space]) -> gymnasium.Spac ) try: - import shimmy # pytype: disable=import-error + import shimmy except ImportError as e: raise ImportError( "Missing shimmy installation. You provided an OpenAI Gym space. " diff --git a/stable_baselines3/common/vec_env/stacked_observations.py b/stable_baselines3/common/vec_env/stacked_observations.py index bf375e165..b6a759f30 100644 --- a/stable_baselines3/common/vec_env/stacked_observations.py +++ b/stable_baselines3/common/vec_env/stacked_observations.py @@ -9,9 +9,6 @@ TObs = TypeVar("TObs", np.ndarray, Dict[str, np.ndarray]) -# Disable errors for pytype which doesn't play well with Generic[TypeVar] -# mypy check passes though -# pytype: disable=attribute-error class StackedObservations(Generic[TObs]): """ Frame stacking wrapper for data. @@ -109,16 +106,14 @@ def reset(self, observation: TObs) -> TObs: :return: The stacked reset observation """ if isinstance(observation, dict): - return { - key: self.sub_stacked_observations[key].reset(obs) for key, obs in observation.items() - } # pytype: disable=bad-return-type + return {key: self.sub_stacked_observations[key].reset(obs) for key, obs in observation.items()} self.stacked_obs[...] = 0 if self.channels_first: self.stacked_obs[:, -observation.shape[self.stack_dimension] :, ...] 
= observation
         else:
             self.stacked_obs[..., -observation.shape[self.stack_dimension] :] = observation
-        return self.stacked_obs # pytype: disable=bad-return-type
+        return self.stacked_obs
 
     def update(
         self,
diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py
index cc8ffdbe4..dbc7002f0 100644
--- a/stable_baselines3/common/vec_env/subproc_vec_env.py
+++ b/stable_baselines3/common/vec_env/subproc_vec_env.py
@@ -109,9 +109,7 @@ def __init__(self, env_fns: List[Callable[[], gym.Env]], start_method: Optional[
         for work_remote, remote, env_fn in zip(self.work_remotes, self.remotes, env_fns):
             args = (work_remote, remote, CloudpickleWrapper(env_fn))
             # daemon=True: if the main process crashes, we should not cause things to hang
-            # pytype: disable=attribute-error
             process = ctx.Process(target=_worker, args=args, daemon=True) # type: ignore[attr-defined]
-            # pytype: enable=attribute-error
             process.start()
             self.processes.append(process)
             work_remote.close()
diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt
index 210ed6b9b..47f323d67 100644
--- a/stable_baselines3/version.txt
+++ b/stable_baselines3/version.txt
@@ -1 +1 @@
-2.2.0a5
+2.2.0a6

From ce56f84ffb8e996929e28d80e861c1dd43032177 Mon Sep 17 00:00:00 2001
From: Antonin Raffin
Date: Wed, 27 Sep 2023 12:18:52 +0200
Subject: [PATCH 4/6] Fix HerReplayBuffer type hints

---
 docs/misc/changelog.rst | 1 +
 pyproject.toml | 1 -
 stable_baselines3/common/buffers.py | 3 +--
 stable_baselines3/her/her_replay_buffer.py | 25 ++++++++++++++--------
 4 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index c2ad4cc0c..3fade42aa 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -50,6 +50,7 @@ Others:
 - Fixed ``stable_baselines3/common/save_util.py`` type hints
 - Updated docker images to Ubuntu Jammy using micromamba 1.5
 - Fixed ``stable_baselines3/common/buffers.py`` type hints
+- Fixed ``stable_baselines3/her/her_replay_buffer.py`` type hints
 - Buffers do not call an additional ``.copy()`` when storing new transitions
 
 Documentation:
diff --git a/pyproject.toml b/pyproject.toml
index 5616b5d84..2d0c61914 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -48,7 +48,6 @@ exclude = """(?x)(
     | stable_baselines3/common/policies.py$
     | stable_baselines3/common/vec_env/__init__.py$
     | stable_baselines3/common/vec_env/vec_normalize.py$
-    | stable_baselines3/her/her_replay_buffer.py$
     | tests/test_logger.py$
     | tests/test_train_eval_mode.py$
 )"""
diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py
index 263d8e100..a230a31b8 100644
--- a/stable_baselines3/common/buffers.py
+++ b/stable_baselines3/common/buffers.py
@@ -231,7 +231,7 @@ def __init__(
             self.observations.nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes
         )
 
-        if self.next_observations is not None:
+        if not optimize_memory_usage:
             total_memory_usage += self.next_observations.nbytes
 
         if total_memory_usage > mem_available:
@@ -742,7 +742,6 @@ def __init__(
         self.reset()
 
     def reset(self) -> None:
-        # assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only"
         self.observations = {}
         for key, obs_input_shape in self.obs_shape.items():
             self.observations[key] = np.zeros((self.buffer_size, self.n_envs, *obs_input_shape), dtype=np.float32)
diff --git a/stable_baselines3/her/her_replay_buffer.py b/stable_baselines3/her/her_replay_buffer.py
index 73a4a01f6..5f0765884 100644
--- 
a/stable_baselines3/her/her_replay_buffer.py +++ b/stable_baselines3/her/her_replay_buffer.py @@ -7,7 +7,7 @@ from gymnasium import spaces from stable_baselines3.common.buffers import DictReplayBuffer -from stable_baselines3.common.type_aliases import DictReplayBufferSamples, TensorDict +from stable_baselines3.common.type_aliases import DictReplayBufferSamples from stable_baselines3.common.vec_env import VecEnv, VecNormalize from stable_baselines3.her.goal_selection_strategy import KEY_TO_GOAL_STRATEGY, GoalSelectionStrategy @@ -45,10 +45,12 @@ class HerReplayBuffer(DictReplayBuffer): False by default. """ + env: Optional[VecEnv] + def __init__( self, buffer_size: int, - observation_space: spaces.Space, + observation_space: spaces.Dict, action_space: spaces.Space, env: VecEnv, device: Union[th.device, str] = "auto", @@ -130,10 +132,10 @@ def set_env(self, env: VecEnv) -> None: self.env = env - def add( + def add( # type: ignore[override] self, - obs: TensorDict, - next_obs: TensorDict, + obs: Dict[str, np.ndarray], + next_obs: Dict[str, np.ndarray], action: np.ndarray, reward: np.ndarray, done: np.ndarray, @@ -181,7 +183,7 @@ def _compute_episode_length(self, env_idx: int) -> None: # Update the current episode start self._current_ep_start[env_idx] = self.pos - def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples: + def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples: # type: ignore[override] """ Sample elements from the replay buffer. @@ -264,6 +266,8 @@ def _get_real_samples( {key: obs[batch_indices, env_indices, :] for key, obs in self.next_observations.items()}, env ) + assert isinstance(obs_, dict) + assert isinstance(next_obs_, dict) # Convert to torch tensor observations = {key: self.to_torch(obs) for key, obs in obs_.items()} next_observations = {key: self.to_torch(obs) for key, obs in next_obs_.items()} @@ -309,6 +313,9 @@ def _get_virtual_samples( # The desired goal for the next observation must be the same as the previous one next_obs["desired_goal"] = new_goals + assert ( + self.env is not None + ), "You must initialize HerReplayBuffer with a VecEnv so it can compute rewards for virtual transitions" # Compute new reward rewards = self.env.env_method( "compute_reward", @@ -326,8 +333,8 @@ def _get_virtual_samples( indices=[0], ) rewards = rewards[0].astype(np.float32) # env_method returns a list containing one element - obs = self._normalize_obs(obs, env) - next_obs = self._normalize_obs(next_obs, env) + obs = self._normalize_obs(obs, env) # type: ignore[assignment] + next_obs = self._normalize_obs(next_obs, env) # type: ignore[assignment] # Convert to torch tensor observations = {key: self.to_torch(obs) for key, obs in obs.items()} @@ -342,7 +349,7 @@ def _get_virtual_samples( dones=self.to_torch( self.dones[batch_indices, env_indices] * (1 - self.timeouts[batch_indices, env_indices]) ).reshape(-1, 1), - rewards=self.to_torch(self._normalize_reward(rewards.reshape(-1, 1), env)), + rewards=self.to_torch(self._normalize_reward(rewards.reshape(-1, 1), env)), # type: ignore[attr-defined] ) def _sample_goals(self, batch_indices: np.ndarray, env_indices: np.ndarray) -> np.ndarray: From 3616b1bcbd3e42d2bb9b56877b8b2cc5bd79ee01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Thu, 28 Sep 2023 00:54:55 +0200 Subject: [PATCH 5/6] try remove # type: ignore[assignment] --- stable_baselines3/common/buffers.py | 2 +- 1 file changed, 
1 insertion(+), 1 deletion(-) diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index a230a31b8..6db56ed5f 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -540,7 +540,7 @@ class DictReplayBuffer(ReplayBuffer): observation_space: spaces.Dict obs_shape: Dict[str, Tuple[int, ...]] # type: ignore[assignment] observations: Dict[str, np.ndarray] # type: ignore[assignment] - next_observations: Dict[str, np.ndarray] # type: ignore[assignment] + next_observations: Dict[str, np.ndarray] def __init__( self, From eee2fcbdcc4f77db49c721105cc990ea71dd703a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Thu, 28 Sep 2023 03:21:48 +0200 Subject: [PATCH 6/6] revert change --- stable_baselines3/common/buffers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index 6db56ed5f..a230a31b8 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -540,7 +540,7 @@ class DictReplayBuffer(ReplayBuffer): observation_space: spaces.Dict obs_shape: Dict[str, Tuple[int, ...]] # type: ignore[assignment] observations: Dict[str, np.ndarray] # type: ignore[assignment] - next_observations: Dict[str, np.ndarray] + next_observations: Dict[str, np.ndarray] # type: ignore[assignment] def __init__( self,
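
Reviewer note: a minimal, self-contained NumPy sketch (hypothetical shapes, not code from this series) of the two buffer behaviours the patches touch -- why the removed ``.copy()`` calls were redundant, and how ``optimize_memory_usage`` keeps the next observation inside ``observations`` itself:

    import numpy as np

    # Hypothetical shapes, for illustration only.
    buffer_size, n_envs, obs_dim = 8, 2, 3
    observations = np.zeros((buffer_size, n_envs, obs_dim), dtype=np.float32)

    # 1) Dropping `.copy()` is safe: `np.array(x)` already copies by default,
    #    and assigning into a preallocated row copies element-wise anyway,
    #    so `np.array(obs).copy()` allocated a second, throwaway array.
    obs = np.ones((n_envs, obs_dim), dtype=np.float32)
    observations[0] = np.array(obs)
    obs[:] = -1.0
    assert observations[0, 0, 0] == 1.0  # the stored transition is unaffected

    # 2) The `optimize_memory_usage` trick: with no separate
    #    `next_observations` array, the next observation is written to
    #    `observations[(pos + 1) % buffer_size]` and read back from there
    #    when sampling, roughly halving observation memory.
    pos = 0
    next_obs = np.full((n_envs, obs_dim), 2.0, dtype=np.float32)
    observations[(pos + 1) % buffer_size] = np.array(next_obs)
    assert (observations[(pos + 1) % buffer_size] == next_obs).all()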