Skip to content

Commit

Permalink
[Feature] Allow multiple observation keys 2 (#83)
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini authored Apr 30, 2024
1 parent 8b2197c commit 0a28a16
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
10 changes: 8 additions & 2 deletions benchmarl/algorithms/qmix.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,19 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:

def get_mixer(self, group: str) -> TensorDictModule:
n_agents = len(self.group_map[group])
group_observation_key = list(self.observation_spec[group].keys())[0]

if self.state_spec is not None:
global_state_key = list(self.state_spec.keys())[0]
global_state_key = list(self.state_spec.keys(True, True))[0]
state_shape = self.state_spec[global_state_key].shape
in_keys = [(group, "chosen_action_value"), global_state_key]
else:
group_observation_keys = list(self.observation_spec[group].keys(True, True))
if len(group_observation_keys) > 1:
raise ValueError(
"QMIX called without a global state and multiple observation keys, currently the mixer"
"takes only one observation key, please raise an issue if you need this fauture."
)
group_observation_key = group_observation_keys[0]
state_shape = self.observation_spec[group, group_observation_key].shape
in_keys = [(group, "chosen_action_value"), (group, group_observation_key)]

Expand Down
23 changes: 22 additions & 1 deletion benchmarl/environments/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,11 +156,32 @@ def group_map(self, env: EnvBase) -> Dict[str, List[str]]:
def observation_spec(self, env: EnvBase) -> CompositeSpec:
"""
A spec for the observation.
Must be a CompositeSpec with one (group_name, observation_key) entry per group.
Must be a CompositeSpec with as many entries as needed nested under the ``group_name`` key.
Args:
env (EnvBase): An environment created via self.get_env_fun
Examples:
>>> print(task.observation_spec(env))
CompositeSpec(
agents: CompositeSpec(
observation: CompositeSpec(
image: UnboundedDiscreteTensorSpec(
shape=torch.Size([8, 88, 88, 3]),
space=ContinuousBox(
low=Tensor(shape=torch.Size([8, 88, 88, 3]), device=cpu, dtype=torch.int64, contiguous=True),
high=Tensor(shape=torch.Size([8, 88, 88, 3]), device=cpu, dtype=torch.int64, contiguous=True)),
device=cpu,
dtype=torch.uint8,
domain=discrete),
array: UnboundedContinuousTensorSpec(
shape=torch.Size([8, 3]),
space=None,
device=cpu,
dtype=torch.float32,
domain=continuous), device=cpu, shape=torch.Size([8])), device=cpu, shape=torch.Size([8])), device=cpu, shape=torch.Size([]))
"""
raise NotImplementedError

Expand Down

0 comments on commit 0a28a16

Please sign in to comment.