Skip to content

Commit

Permalink
Add MO MaxAndSkipWrapper (#76)
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasAlegre authored Aug 24, 2023
1 parent 31d510f commit fc08232
Showing 1 changed file with 53 additions and 0 deletions.
53 changes: 53 additions & 0 deletions mo_gymnasium/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,3 +307,56 @@ def step(self, action):
truncations,
infos,
)


class MOMaxAndSkipObservation(gym.Wrapper):
"""This wrapper will return only every ``skip``-th frame (frameskipping) and return the max between the two last observations.
Note: This wrapper is based on the wrapper from stable-baselines3: https://stable-baselines3.readthedocs.io/en/master/_modules/stable_baselines3/common/atari_wrappers.html#MaxAndSkipEnv
"""

def __init__(self, env: gym.Env[ObsType, ActType], skip: int = 4):
"""This wrapper will return only every ``skip``-th frame (frameskipping) and return the max between the two last frames.
Args:
env (Env): The environment to apply the wrapper
skip: The number of frames to skip
"""
gym.Wrapper.__init__(self, env)

if not np.issubdtype(type(skip), np.integer):
raise TypeError(f"The skip is expected to be an integer, actual type: {type(skip)}")
if skip < 2:
raise ValueError(f"The skip value needs to be equal or greater than two, actual value: {skip}")
if env.observation_space.shape is None:
raise ValueError("The observation space must have the shape attribute.")

self._skip = skip
self._obs_buffer = np.zeros((2, *env.observation_space.shape), dtype=env.observation_space.dtype)

def step(self, action):
"""Step the environment with the given action for ``skip`` steps.
Repeat action, sum reward, and max over last observations.
Args:
action: The action to step through the environment with
Returns:
Max of the last two observations, reward, terminated, truncated, and info from the environment
"""
total_reward = np.zeros(self.env.reward_dim, dtype=np.float32)
terminated = truncated = False
info = {}
for i in range(self._skip):
obs, reward, terminated, truncated, info = self.env.step(action)
done = terminated or truncated
if i == self._skip - 2:
self._obs_buffer[0] = obs
if i == self._skip - 1:
self._obs_buffer[1] = obs
total_reward += reward
if done:
break
max_frame = self._obs_buffer.max(axis=0)

return max_frame, total_reward, terminated, truncated, info

0 comments on commit fc08232

Please sign in to comment.