diff --git a/sinergym/utils/callbacks.py b/sinergym/utils/callbacks.py index bf0a91cda..8b2182abf 100644 --- a/sinergym/utils/callbacks.py +++ b/sinergym/utils/callbacks.py @@ -359,7 +359,7 @@ def _on_step(self) -> bool: self._is_success_buffer = [] # We close training env before to start the evaluation - self.training_env.close() + self.train_env.close() self._sync_envs() @@ -375,7 +375,7 @@ def _on_step(self) -> bool: # We close evaluation env and starts training env again self.eval_env.close() - self.training_env.reset() + self.train_env.reset() if self.log_path is not None: for key, value in episodes_data.items(): @@ -502,5 +502,7 @@ def _sync_envs(self): self.eval_env, NormalizeObservation): self.eval_env.get_wrapper_attr('deactivate_update')() - self.eval_env.obs_rms = deepcopy( - self.train_env.get_wrapper_attr('obs_rms')) + self.eval_env.get_wrapper_attr('set_mean')( + self.train_env.get_wrapper_attr('mean')) + self.eval_env.get_wrapper_attr('set_var')( + self.train_env.get_wrapper_attr('var')) diff --git a/sinergym/utils/wrappers.py b/sinergym/utils/wrappers.py index d8d2fef19..b78af6903 100644 --- a/sinergym/utils/wrappers.py +++ b/sinergym/utils/wrappers.py @@ -56,7 +56,7 @@ def step(self, action: Union[int, np.ndarray]) -> Tuple[ return obs, reward_vector, terminated, truncated, info -class NormalizeObservation(gym.Wrapper, gym.utils.RecordConstructorArgs): +class NormalizeObservation(gym.Wrapper): logger = Logger().getLogger(name='WRAPPER NormalizeObservation', level=LOG_WRAPPERS_LEVEL) @@ -82,10 +82,6 @@ def __init__(self, mean = self._check_and_update_metric(mean, 'mean') var = self._check_and_update_metric(var, 'var') - # Save normalization configuration for whole python process - gym.utils.RecordConstructorArgs.__init__( - self, epsilon=epsilon, mean=mean, var=var) - self.num_envs = 1 self.is_vector_env = False self.automatic_update = automatic_update @@ -196,7 +192,7 @@ def mean(self) -> Optional[np.float64]: def var(self) -> Optional[np.float64]: """Returns the variance value of the observations.""" if hasattr(self, 'obs_rms'): - return self.obs_rms.mean + return self.obs_rms.var else: return None diff --git a/sinergym/version.txt b/sinergym/version.txt index 010d183f8..7cb75caa9 100644 --- a/sinergym/version.txt +++ b/sinergym/version.txt @@ -1 +1 @@ -3.3.7 \ No newline at end of file +3.3.8 \ No newline at end of file