Skip to content

Commit

Permalink
WIP - working role-relative obs wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
alexpalms committed Sep 21, 2023
1 parent 63573f6 commit 3e6507b
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 11 deletions.
60 changes: 51 additions & 9 deletions diambra/arena/wrappers/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import Mapping
import cv2 # pytype:disable=import-error
cv2.ocl.setUseOpenCL(False)
from diambra.engine import model
from diambra.engine import Roles

# Env Wrappers classes
class GrayscaleFrame(gym.ObservationWrapper):
Expand Down Expand Up @@ -274,17 +274,59 @@ def _obs_normalization_func(self, observation, observation_space):

return observation

class RoleRelativeObservation(gym.Wrapper):
    """Re-key per-player observations from absolute roles to relative ones.

    The wrapped environment exposes player data under the "P1" and "P2"
    keys. This wrapper exposes the same data under "own" (the role the
    agent is playing in the current episode) and "opp" (the other role):

    - With one player, "own" and "opp" are placed at the top level of the
      observation and every other nested (dict) entry is dropped.
    - With more players, "own" and "opp" are placed under each
      "agent_<idx>" entry, resolved from that agent's role.

    Roles are read from ``info["settings"]`` at every ``reset``, so the
    mapping can change between episodes.
    """

    def __init__(self, env):
        """Wrap ``env`` and build the role-relative observation space.

        Args:
            env: environment whose observation space is a dict space
                containing a "P1" entry, and whose unwrapped instance
                exposes ``env_settings.n_players``.
        """
        gym.Wrapper.__init__(self, env)

        # Role name ("P1"/"P2") of each agent; filled in by reset().
        self._role_names = None

        source_space = self.observation_space
        # NOTE(review): as in the original, both "own" and "opp" reuse the
        # "P1" sub-space; this presumes the P1 and P2 sub-spaces are
        # identical -- confirm against the environment definition.
        player_space = source_space["P1"]
        n_players = self.unwrapped.env_settings.n_players

        new_observation_space = {}
        if n_players == 1:
            for k, v in source_space.items():
                if not isinstance(v, gym.spaces.Dict):
                    new_observation_space[k] = v
            new_observation_space["own"] = player_space
            new_observation_space["opp"] = player_space
        else:
            for k, v in source_space.items():
                if not isinstance(v, gym.spaces.Dict) or k.startswith("agent_"):
                    new_observation_space[k] = v
            for idx in range(n_players):
                agent_key = "agent_{}".format(idx)
                existing = new_observation_space.get(agent_key)
                # Build a fresh per-agent space: this avoids a KeyError when
                # the wrapped env has no "agent_<idx>" entry and avoids
                # mutating the wrapped env's own space object in place.
                if isinstance(existing, gym.spaces.Dict):
                    agent_space = dict(existing.items())
                else:
                    agent_space = {}
                agent_space["own"] = player_space
                agent_space["opp"] = player_space
                new_observation_space[agent_key] = gym.spaces.Dict(agent_space)

        self.observation_space = gym.spaces.Dict(new_observation_space)

    @staticmethod
    def _opponent_of(role_name):
        """Return the role name opposite to ``role_name`` ("P1" <-> "P2")."""
        return "P2" if role_name == "P1" else "P1"

    def _process_obs(self, observation):
        """Return a new observation dict keyed by "own"/"opp".

        The input observation is not modified.

        Raises:
            RuntimeError: if called before ``reset`` has resolved the roles.
        """
        if self._role_names is None:
            raise RuntimeError(
                "RoleRelativeObservation: reset() must be called before step()"
            )

        new_observation = {}
        if len(self._role_names) == 1:
            role_name = self._role_names[0]
            for k, v in observation.items():
                if not isinstance(v, dict):
                    new_observation[k] = v
            new_observation["own"] = observation[role_name]
            new_observation["opp"] = observation[self._opponent_of(role_name)]
            return new_observation

        for k, v in observation.items():
            if not isinstance(v, dict) or k.startswith("agent_"):
                new_observation[k] = v
        for idx, role_name in enumerate(self._role_names):
            agent_key = "agent_{}".format(idx)
            existing = new_observation.get(agent_key)
            # Copy instead of writing into the env's nested dict, and start
            # from an empty dict when the env provides no per-agent entry.
            agent_obs = dict(existing) if isinstance(existing, dict) else {}
            agent_obs["own"] = observation[role_name]
            agent_obs["opp"] = observation[self._opponent_of(role_name)]
            new_observation[agent_key] = agent_obs
        return new_observation

    def reset(self, **kwargs):
        """Reset the env, resolve each agent's role, and re-key the observation.

        Returns:
            The ``(observation, info)`` pair from the wrapped env, with the
            observation converted to its role-relative form.
        """
        obs, info = self.env.reset(**kwargs)
        player_settings = info["settings"].episode_settings.player_settings
        # Roles are resolved once per episode and reused by every step().
        self._role_names = [
            Roles.Name(player_settings[idx].role)
            for idx in range(self.unwrapped.env_settings.n_players)
        ]
        return self._process_obs(obs), info

    def step(self, action):
        """Step the env and return its result with a role-relative observation."""
        obs, reward, terminated, truncated, info = self.env.step(action)
        return self._process_obs(obs), reward, terminated, truncated, info

"""
def rename_key_recursive(dictionary, old_key, new_key):
Expand Down
3 changes: 1 addition & 2 deletions examples/wrappers_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def main():
#wrappers_settings["exclude_image_scaling"] = True
#wrappers_settings["process_discrete_binary"] = True

"""
# Whether to make the observation relative to the agent as a function of its role (P1 or P2) (deactivated by default)
# i.e.:
# - In 1P environments, if the agent is P1 then the observation "P1" nesting level becomes "own" and "P2" becomes "opp"
Expand All @@ -74,7 +73,7 @@ def main():
# - Under "agent_1", "P1" nesting level becomes "opp" and "P2" becomes "own"
wrappers_settings["role_relative_observation"] = True

"""
# Flattening observation dictionary and filtering
# a sub-set of the RAM states
wrappers_settings["flatten"] = True
Expand Down

0 comments on commit 3e6507b

Please sign in to comment.