Skip to content

Commit

Permalink
Manage the original observation as a property in Normalization wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
AlejandroCN7 committed Sep 14, 2023
1 parent 9e6eb21 commit 9a1968b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 12 deletions.
2 changes: 1 addition & 1 deletion sinergym/utils/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def _on_step(self) -> bool:
if self.training_env.env_is_wrapped(
wrapper_class=NormalizeObservation)[0]:
obs_normalized = self.locals['new_obs'][-1]
obs = self.training_env.env_method('get_unwrapped_obs')[-1]
obs = self.training_env.get_attr('unwrapped_observation')[-1]
for i, variable in enumerate(observation_variables):
self.logger.record(
'normalized_observation/' + variable, obs_normalized[i])
Expand Down
14 changes: 3 additions & 11 deletions sinergym/utils/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,6 @@ def normalize(self, obs):
return (obs - self.obs_rms.mean) / \
np.sqrt(self.obs_rms.var + self.epsilon)

def get_unwrapped_obs(self) -> Optional[np.ndarray]:
"""Get last environment observation without normalization.
Returns:
Optional[np.ndarray]: Last original observation. If it is the first observation, this value is None.
"""
return self.get_wrapper_attr('unwrapped_observation')


class MultiObsWrapper(gym.Wrapper):

Expand Down Expand Up @@ -266,7 +258,7 @@ def step(self, action: Union[int, np.ndarray]
info=info)
# Record original observation too
self.file_logger.log_step(
obs=self.env.get_unwrapped_obs(),
obs=self.env.get_wrapper_attr('unwrapped_observation'),
action=info['action'],
terminated=terminated,
info=info)
Expand Down Expand Up @@ -317,7 +309,7 @@ def reset(self,
self.file_logger.log_step_normalize(obs=obs, action=[None for _ in range(len(
self.env.get_wrapper_attr('action_variables')))], terminated=False, info=info)
# And store original obs
self.file_logger.log_step(obs=self.env.get_unwrapped_obs(),
self.file_logger.log_step(obs=self.env.get_wrapper_attr('unwrapped_observation'),
action=[None for _ in range(
len(self.get_wrapper_attr('action_variables')))],
terminated=False,
Expand Down Expand Up @@ -603,7 +595,7 @@ def action(self, act: np.ndarray) -> np.ndarray:
Returns:
np.ndarray: Action Clipped
"""
if self.flag_discrete:
if self.get_wrapper_attr('flag_discrete'):
null_value = 0.0
else:
# -1.0 is 0.0 when action space transformation to simulator action space.
Expand Down

0 comments on commit 9a1968b

Please sign in to comment.