Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Jun 8, 2024
1 parent 0a28a16 commit 82c383a
Show file tree
Hide file tree
Showing 8 changed files with 104 additions and 22 deletions.
12 changes: 12 additions & 0 deletions benchmarl/conf/task/vmas/discovery.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
defaults:
- _self_
- vmas_discovery_config

max_steps: 100
n_agents: 5
n_targets: 7
lidar_range: 0.35
covering_range: 0.25
agents_per_target: 2
targets_respawn: True
shared_reward: True
3 changes: 2 additions & 1 deletion benchmarl/environments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from .pettingzoo.waterworld import TaskConfig as WaterworldConfig

from .vmas.balance import TaskConfig as BalanceConfig
from .vmas.discovery import TaskConfig as DiscoveryConfig
from .vmas.dispersion import TaskConfig as DispersionConfig
from .vmas.dropout import TaskConfig as DropoutConfig
from .vmas.give_way import TaskConfig as GiveWayConfig
Expand All @@ -52,7 +53,6 @@
from .vmas.wheel import TaskConfig as WheelConfig
from .vmas.wind_flocking import TaskConfig as WindFlockingConfig


# This is a registry mapping task config schemas names to their python dataclass
# It is used by hydra to validate loaded configs.
# You will see the "envname_taskname_config" strings in the hydra defaults at the top of yaml files.
Expand All @@ -68,6 +68,7 @@
"vmas_give_way_config": GiveWayConfig,
"vmas_wind_flocking_config": WindFlockingConfig,
"vmas_dropout_config": DropoutConfig,
"vmas_discovery_config": DiscoveryConfig,
"vmas_simple_adversary_config": VmasSimpleAdversaryConfig,
"vmas_simple_crypto_config": VmasSimpleCryptoConfig,
"vmas_simple_push_config": VmasSimplePushConfig,
Expand Down
36 changes: 20 additions & 16 deletions benchmarl/environments/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import importlib
import os
import os.path as osp
import warnings
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
Expand All @@ -23,22 +24,25 @@
def _load_config(name: str, config: Dict[str, Any]):
if not name.endswith(".py"):
name += ".py"

environemnt_name, task_name = name.split("/")
pathname = None
for dirpath, _, filenames in os.walk(osp.dirname(__file__)):
if pathname is None:
for filename in filenames:
if filename == name:
pathname = os.path.join(dirpath, filename)
break

if pathname is None:
raise ValueError(f"Task {name} not found.")

spec = importlib.util.spec_from_file_location("", pathname)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module.TaskConfig(**config).__dict__
for dirpath, _, filenames in os.walk(
Path(osp.dirname(__file__)) / environemnt_name
):
if task_name in filenames:
pathname = os.path.join(dirpath, task_name)
break

if pathname is not None:
spec = importlib.util.spec_from_file_location("", pathname)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module.TaskConfig(**config).__dict__
else:
warnings.warn(
"TaskConfig python dataclass not foud, task is being loaded without type checks"
)
return config


class Task(Enum):
Expand Down Expand Up @@ -302,7 +306,7 @@ def __str__(self):
@staticmethod
def _load_from_yaml(name: str) -> Dict[str, Any]:
yaml_path = Path(__file__).parent.parent / "conf" / "task" / f"{name}.yaml"
return _read_yaml_config(str(yaml_path.resolve()))
return _load_config(name, _read_yaml_config(str(yaml_path.resolve())))

def get_from_yaml(self, path: Optional[str] = None) -> Task:
"""
Expand Down
1 change: 1 addition & 0 deletions benchmarl/environments/pettingzoo/simple_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@
class TaskConfig:
task: str = MISSING
max_cycles: int = MISSING
local_ratio: float = MISSING
9 changes: 9 additions & 0 deletions benchmarl/environments/vmas/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,18 @@ class VmasTask(Task):
REVERSE_TRANSPORT = None
WHEEL = None
DISPERSION = None
MULTI_GIVE_WAY = None
DROPOUT = None
GIVE_WAY = None
WIND_FLOCKING = None
PASSAGE = None
JOINT_PASSAGE = None
JOINT_PASSAGE_SIZE = None
BALL_PASSAGE = None
BALL_TRAJECTORY = None
BUZZ_WIRE = None
FLOCKING = None
DISCOVERY = None
SIMPLE_ADVERSARY = None
SIMPLE_CRYPTO = None
SIMPLE_PUSH = None
Expand Down
19 changes: 19 additions & 0 deletions benchmarl/environments/vmas/discovery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

from dataclasses import dataclass, MISSING


@dataclass
class TaskConfig:
max_steps: int = MISSING
n_agents: int = MISSING
n_targets: int = MISSING
lidar_range: float = MISSING
covering_range: float = MISSING
agents_per_target: int = MISSING
targets_respawn: bool = MISSING
shared_reward: bool = MISSING
10 changes: 9 additions & 1 deletion benchmarl/hydra_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from benchmarl.algorithms.common import AlgorithmConfig
from benchmarl.environments import Task, task_config_registry
from benchmarl.environments.common import _load_config
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models import model_config_registry
from benchmarl.models.common import ModelConfig, parse_model_config, SequenceModelConfig
Expand Down Expand Up @@ -56,7 +57,14 @@ def load_task_config_from_hydra(cfg: DictConfig, task_name: str) -> Task:
"""
return task_config_registry[task_name].update_config(
OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
_load_config(
task_name,
OmegaConf.to_container(
cfg,
resolve=True,
throw_on_missing=True,
),
)
)


Expand Down
36 changes: 32 additions & 4 deletions test/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,24 @@
# LICENSE file in the root directory of this source tree.
#

import pytest
import contextlib

from benchmarl.environments import Task, task_config_registry
import pytest
from benchmarl.environments import _task_class_registry, Task, task_config_registry
from benchmarl.hydra_config import load_task_config_from_hydra
from hydra import compose, initialize


@pytest.mark.parametrize("task_name", task_config_registry.keys())
def test_loading_tasks(task_name):
task_dataclasses_names = list(_task_class_registry.keys())
config_task_name = task_name.replace("/", "_")
task_has_dataclass = False
for task_dataclass_name in task_dataclasses_names:
if config_task_name in task_dataclass_name:
task_has_dataclass = True
break

with initialize(version_base=None, config_path="../benchmarl/conf"):
cfg = compose(
config_name="config",
Expand All @@ -23,5 +32,24 @@ def test_loading_tasks(task_name):
return_hydra_config=True,
)
task_name_hydra = cfg.hydra.runtime.choices.task
task: Task = load_task_config_from_hydra(cfg.task, task_name=task_name_hydra)
assert task == task_config_registry[task_name].get_from_yaml()
assert task_name_hydra == task_name

warn_message = "TaskConfig python dataclass not foud, task is being loaded without type checks"

with (
pytest.warns(match=warn_message)
if not task_has_dataclass
else contextlib.nullcontext()
):
task: Task = load_task_config_from_hydra(
cfg.task, task_name=task_name_hydra
)

with (
pytest.warns(match=warn_message)
if not task_has_dataclass
else contextlib.nullcontext()
):
task_from_yaml: Task = task_config_registry[task_name].get_from_yaml()

assert task == task_from_yaml

0 comments on commit 82c383a

Please sign in to comment.