Skip to content

Commit

Permalink
(v3.3.8) - Observation normalization bug Fix (again), negative values…
Browse files Browse the repository at this point in the history
… in obs_rms.var (#422)

* Evl Callback: Using argument train_env instyead oh inhereted training environment

* Evl Callback: Fixed mean and var normalization calibration set (now it is applied correctly)

* Normalization Wrapper: Deleted RecordConstructorArgs

* Normalization wrapper: Deleted RecordConstructorArgs inherit

* Normalize Wrapper: Fixed var property bug (returning mean again instead of var)

* Updated Sinergym version from 3.3.7 to 3.3.8
  • Loading branch information
AlejandroCN7 authored Jul 3, 2024
1 parent 1aab03a commit d7d8c6e
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 11 deletions.
10 changes: 6 additions & 4 deletions sinergym/utils/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def _on_step(self) -> bool:
self._is_success_buffer = []

# We close training env before to start the evaluation
self.training_env.close()
self.train_env.close()

self._sync_envs()

Expand All @@ -375,7 +375,7 @@ def _on_step(self) -> bool:

# We close evaluation env and starts training env again
self.eval_env.close()
self.training_env.reset()
self.train_env.reset()

if self.log_path is not None:
for key, value in episodes_data.items():
Expand Down Expand Up @@ -502,5 +502,7 @@ def _sync_envs(self):
self.eval_env,
NormalizeObservation):
self.eval_env.get_wrapper_attr('deactivate_update')()
self.eval_env.obs_rms = deepcopy(
self.train_env.get_wrapper_attr('obs_rms'))
self.eval_env.get_wrapper_attr('set_mean')(
self.train_env.get_wrapper_attr('mean'))
self.eval_env.get_wrapper_attr('set_var')(
self.train_env.get_wrapper_attr('var'))
8 changes: 2 additions & 6 deletions sinergym/utils/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def step(self, action: Union[int, np.ndarray]) -> Tuple[
return obs, reward_vector, terminated, truncated, info


class NormalizeObservation(gym.Wrapper, gym.utils.RecordConstructorArgs):
class NormalizeObservation(gym.Wrapper):

logger = Logger().getLogger(name='WRAPPER NormalizeObservation',
level=LOG_WRAPPERS_LEVEL)
Expand All @@ -82,10 +82,6 @@ def __init__(self,
mean = self._check_and_update_metric(mean, 'mean')
var = self._check_and_update_metric(var, 'var')

# Save normalization configuration for whole python process
gym.utils.RecordConstructorArgs.__init__(
self, epsilon=epsilon, mean=mean, var=var)

self.num_envs = 1
self.is_vector_env = False
self.automatic_update = automatic_update
Expand Down Expand Up @@ -196,7 +192,7 @@ def mean(self) -> Optional[np.float64]:
def var(self) -> Optional[np.float64]:
"""Returns the variance value of the observations."""
if hasattr(self, 'obs_rms'):
return self.obs_rms.mean
return self.obs_rms.var
else:
return None

Expand Down
2 changes: 1 addition & 1 deletion sinergym/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.3.7
3.3.8

0 comments on commit d7d8c6e

Please sign in to comment.