diff --git a/README.md b/README.md index e63cba6d..e99687fb 100644 --- a/README.md +++ b/README.md @@ -253,7 +253,7 @@ They differ based on many aspects, here is a table with the current environments > BenchMARL uses the [TorchRL MARL API](https://github.com/pytorch/rl/issues/1463) for grouping agents. > In competitive environments like MPE, for example, teams will be in different groups. Each group has its own loss, > models, buffers, and so on. Parameter sharing options refer to sharing within the group. See the example on [creating -> a custom algorithm](examples/extending/algorithm/custom_algorithm.py) for more info. +> a custom algorithm](examples/extending/algorithm/algorithms/customalgorithm.py) for more info. **Models**. Models are neural networks used to process data. They can be used as actors (policies) or, when requested, as critics. We provide a set of base models (layers) and a SequenceModel to concatenate diff --git a/benchmarl/__init__.py b/benchmarl/__init__.py index 826a1e98..d5161694 100644 --- a/benchmarl/__init__.py +++ b/benchmarl/__init__.py @@ -35,6 +35,6 @@ def _load_hydra_schemas(): cs.store(name=f"{algo_name}_config", group="algorithm", node=algo_schema) # Load task schemas for task_schema_name, task_schema in _task_class_registry.items(): - cs.store(name=task_schema_name, group="task", node=task_schema) + cs.store(name=f"{task_schema_name}_config", group="task", node=task_schema) _load_hydra_schemas() diff --git a/benchmarl/algorithms/common.py b/benchmarl/algorithms/common.py index 28989199..e742d808 100644 --- a/benchmarl/algorithms/common.py +++ b/benchmarl/algorithms/common.py @@ -358,14 +358,15 @@ def get_from_yaml(cls, path: Optional[str] = None): Returns: the loaded AlgorithmConfig """ + if path is None: - return cls( - **AlgorithmConfig._load_from_yaml( - name=cls.associated_class().__name__, - ) + config = AlgorithmConfig._load_from_yaml( + name=cls.associated_class().__name__ ) + else: - return cls(**_read_yaml_config(path)) + config = _read_yaml_config(path) + return cls(**config) @staticmethod @abstractmethod diff --git a/benchmarl/conf/task/pettingzoo/multiwalker.yaml b/benchmarl/conf/task/pettingzoo/multiwalker.yaml index ca6d5299..8b8f05ae 100644 --- a/benchmarl/conf/task/pettingzoo/multiwalker.yaml +++ b/benchmarl/conf/task/pettingzoo/multiwalker.yaml @@ -1,6 +1,6 @@ defaults: - - _self_ - pettingzoo_multiwalker_config + - _self_ task: "multiwalker_v9" # number of bipedal walker agents in environment diff --git a/benchmarl/conf/task/pettingzoo/simple_adversary.yaml b/benchmarl/conf/task/pettingzoo/simple_adversary.yaml index a0ebe319..fc17797f 100644 --- a/benchmarl/conf/task/pettingzoo/simple_adversary.yaml +++ b/benchmarl/conf/task/pettingzoo/simple_adversary.yaml @@ -1,6 +1,7 @@ defaults: - - _self_ - pettingzoo_simple_adversary_config + - _self_ + task: "simple_adversary_v3" N: 2 diff --git a/benchmarl/conf/task/pettingzoo/simple_crypto.yaml b/benchmarl/conf/task/pettingzoo/simple_crypto.yaml index cd0dc290..2c747e64 100644 --- a/benchmarl/conf/task/pettingzoo/simple_crypto.yaml +++ b/benchmarl/conf/task/pettingzoo/simple_crypto.yaml @@ -1,6 +1,7 @@ defaults: - - _self_ - pettingzoo_simple_crypto_config + - _self_ + task: "simple_crypto_v3" max_cycles: 100 diff --git a/benchmarl/conf/task/pettingzoo/simple_push.yaml b/benchmarl/conf/task/pettingzoo/simple_push.yaml index 98ad5836..91d6ffed 100644 --- a/benchmarl/conf/task/pettingzoo/simple_push.yaml +++ b/benchmarl/conf/task/pettingzoo/simple_push.yaml @@ -1,6 +1,7 @@ defaults: + - pettingzoo_simple_push_config - _self_ - - pettingzoo_simple_push_config + task: "simple_push_v3" max_cycles: 100 diff --git a/benchmarl/conf/task/pettingzoo/simple_reference.yaml b/benchmarl/conf/task/pettingzoo/simple_reference.yaml index b4734c2b..72e68a07 100644 --- a/benchmarl/conf/task/pettingzoo/simple_reference.yaml +++ b/benchmarl/conf/task/pettingzoo/simple_reference.yaml @@ -1,6 +1,7 @@ defaults: - - _self_ - pettingzoo_simple_reference_config + - _self_ + task: "simple_reference_v3" max_cycles: 100 diff --git a/benchmarl/conf/task/pettingzoo/simple_speaker_listener.yaml b/benchmarl/conf/task/pettingzoo/simple_speaker_listener.yaml index bcb48d9d..1b5f8d7c 100644 --- a/benchmarl/conf/task/pettingzoo/simple_speaker_listener.yaml +++ b/benchmarl/conf/task/pettingzoo/simple_speaker_listener.yaml @@ -1,6 +1,7 @@ defaults: - - _self_ - pettingzoo_simple_speaker_listener_config + - _self_ + task: "simple_speaker_listener_v4" max_cycles: 100 diff --git a/benchmarl/conf/task/pettingzoo/simple_spread.yaml b/benchmarl/conf/task/pettingzoo/simple_spread.yaml index f4d4bd75..a6f1b663 100644 --- a/benchmarl/conf/task/pettingzoo/simple_spread.yaml +++ b/benchmarl/conf/task/pettingzoo/simple_spread.yaml @@ -1,6 +1,7 @@ defaults: - - _self_ - pettingzoo_simple_spread_config + - _self_ + task: "simple_spread_v3" max_cycles: 100 diff --git a/benchmarl/conf/task/pettingzoo/simple_tag.yaml b/benchmarl/conf/task/pettingzoo/simple_tag.yaml index 1e71b47f..97374bba 100644 --- a/benchmarl/conf/task/pettingzoo/simple_tag.yaml +++ b/benchmarl/conf/task/pettingzoo/simple_tag.yaml @@ -1,6 +1,7 @@ defaults: - - _self_ - pettingzoo_simple_tag_config + - _self_ + task: "simple_tag_v3" num_good: 2 diff --git a/benchmarl/conf/task/pettingzoo/simple_world_comm.yaml b/benchmarl/conf/task/pettingzoo/simple_world_comm.yaml index 0ef8c4da..32053c79 100644 --- a/benchmarl/conf/task/pettingzoo/simple_world_comm.yaml +++ b/benchmarl/conf/task/pettingzoo/simple_world_comm.yaml @@ -1,6 +1,7 @@ defaults: - - _self_ - pettingzoo_simple_world_comm_config + - _self_ + task: "simple_world_comm_v3" num_good: 2 diff --git a/benchmarl/conf/task/pettingzoo/waterworld.yaml b/benchmarl/conf/task/pettingzoo/waterworld.yaml index cc570761..18fae3d3 100644 --- a/benchmarl/conf/task/pettingzoo/waterworld.yaml +++ b/benchmarl/conf/task/pettingzoo/waterworld.yaml @@ -1,6 +1,7 @@ defaults: - - _self_ - pettingzoo_waterworld_config + - _self_ + task: "waterworld_v4" max_cycles: 500 diff --git a/benchmarl/conf/task/vmas/balance.yaml b/benchmarl/conf/task/vmas/balance.yaml index e7614ce8..4aa23a09 100644 --- a/benchmarl/conf/task/vmas/balance.yaml +++ b/benchmarl/conf/task/vmas/balance.yaml @@ -1,7 +1,6 @@ defaults: - - _self_ - vmas_balance_config - + - _self_ max_steps: 100 n_agents: 4 diff --git a/benchmarl/conf/task/vmas/ball_passage.yaml b/benchmarl/conf/task/vmas/ball_passage.yaml index 26cbf773..9001a1d6 100644 --- a/benchmarl/conf/task/vmas/ball_passage.yaml +++ b/benchmarl/conf/task/vmas/ball_passage.yaml @@ -1,6 +1,6 @@ defaults: - - _self_ - vmas_ball_passage_config + - _self_ max_steps: 500 n_passages: 1 diff --git a/benchmarl/conf/task/vmas/ball_trajectory.yaml b/benchmarl/conf/task/vmas/ball_trajectory.yaml index 5babd799..d42a2509 100644 --- a/benchmarl/conf/task/vmas/ball_trajectory.yaml +++ b/benchmarl/conf/task/vmas/ball_trajectory.yaml @@ -1,6 +1,7 @@ defaults: - - _self_ - vmas_ball_trajectory_config + - _self_ + max_steps: 100 joints: True diff --git a/benchmarl/conf/task/vmas/buzz_wire.yaml b/benchmarl/conf/task/vmas/buzz_wire.yaml index 7dc696b3..5cc5fb06 100644 --- a/benchmarl/conf/task/vmas/buzz_wire.yaml +++ b/benchmarl/conf/task/vmas/buzz_wire.yaml @@ -1,6 +1,7 @@ defaults: - - _self_ - vmas_buzz_wire_config + - _self_ + max_steps: 100 random_start_angle: True diff --git a/benchmarl/conf/task/vmas/discovery.yaml b/benchmarl/conf/task/vmas/discovery.yaml index 129ccf51..a0b443c4 100644 --- a/benchmarl/conf/task/vmas/discovery.yaml +++ b/benchmarl/conf/task/vmas/discovery.yaml @@ -1,6 +1,7 @@ defaults: - - _self_ - vmas_discovery_config + - _self_ + max_steps: 100 n_agents: 5 diff --git a/benchmarl/conf/task/vmas/dispersion.yaml b/benchmarl/conf/task/vmas/dispersion.yaml index 63f9adc0..c2ae15df 100644 --- a/benchmarl/conf/task/vmas/dispersion.yaml +++ b/benchmarl/conf/task/vmas/dispersion.yaml @@ -1,6 +1,7 @@ defaults: - - _self_ - vmas_dispersion_config + - _self_ + max_steps: 100 diff --git a/benchmarl/conf/task/vmas/dropout.yaml b/benchmarl/conf/task/vmas/dropout.yaml index e7e50533..ad3a39aa 100644 --- a/benchmarl/conf/task/vmas/dropout.yaml +++ b/benchmarl/conf/task/vmas/dropout.yaml @@ -1,6 +1,7 @@ defaults: - - _self_ - vmas_dropout_config + - _self_ + max_steps: 100 diff --git a/benchmarl/conf/task/vmas/flocking.yaml b/benchmarl/conf/task/vmas/flocking.yaml index dafc0cc4..e27d9905 100644 --- a/benchmarl/conf/task/vmas/flocking.yaml +++ b/benchmarl/conf/task/vmas/flocking.yaml @@ -1,6 +1,7 @@ defaults: - - _self_ - vmas_flocking_config + - _self_ + max_steps: 100 n_agents: 4 diff --git a/benchmarl/conf/task/vmas/give_way.yaml b/benchmarl/conf/task/vmas/give_way.yaml index 4f1b29fe..44c9dc32 100644 --- a/benchmarl/conf/task/vmas/give_way.yaml +++ b/benchmarl/conf/task/vmas/give_way.yaml @@ -1,6 +1,7 @@ defaults: - - _self_ - vmas_give_way_config + - _self_ + max_steps: 100 mirror_passage: False diff --git a/benchmarl/conf/task/vmas/joint_passage.yaml b/benchmarl/conf/task/vmas/joint_passage.yaml index 918840b4..da4588b4 100644 --- a/benchmarl/conf/task/vmas/joint_passage.yaml +++ b/benchmarl/conf/task/vmas/joint_passage.yaml @@ -1,6 +1,7 @@ defaults: - - _self_ - vmas_joint_passage_config + - _self_ + max_steps: 500 n_passages: 1 diff --git a/benchmarl/conf/task/vmas/joint_passage_size.yaml b/benchmarl/conf/task/vmas/joint_passage_size.yaml index ae7fe78e..13234c8d 100644 --- a/benchmarl/conf/task/vmas/joint_passage_size.yaml +++ b/benchmarl/conf/task/vmas/joint_passage_size.yaml @@ -1,6 +1,7 @@ defaults: - - _self_ - vmas_joint_passage_size_config + - _self_ + max_steps: 500 n_passages: 3 diff --git a/benchmarl/conf/task/vmas/multi_give_way.yaml b/benchmarl/conf/task/vmas/multi_give_way.yaml index 8c57ba7a..e23eee60 100644 --- a/benchmarl/conf/task/vmas/multi_give_way.yaml +++ b/benchmarl/conf/task/vmas/multi_give_way.yaml @@ -1,6 +1,7 @@ defaults: - - _self_ - vmas_multi_give_way_config + - _self_ + max_steps: 200 agent_collision_penalty: -0.1 diff --git a/benchmarl/conf/task/vmas/navigation.yaml b/benchmarl/conf/task/vmas/navigation.yaml index 3a4b797a..b9d2e2d0 100644 --- a/benchmarl/conf/task/vmas/navigation.yaml +++ b/benchmarl/conf/task/vmas/navigation.yaml @@ -1,6 +1,7 @@ defaults: - - _self_ - vmas_navigation_config + - _self_ + max_steps: 100 diff --git a/benchmarl/conf/task/vmas/passage.yaml b/benchmarl/conf/task/vmas/passage.yaml index af625b49..e3563f3a 100644 --- a/benchmarl/conf/task/vmas/passage.yaml +++ b/benchmarl/conf/task/vmas/passage.yaml @@ -1,6 +1,7 @@ defaults: - - _self_ - vmas_passage_config + - _self_ + max_steps: 500 n_passages: 1 diff --git a/benchmarl/conf/task/vmas/reverse_transport.yaml b/benchmarl/conf/task/vmas/reverse_transport.yaml index db09c199..3869832d 100644 --- a/benchmarl/conf/task/vmas/reverse_transport.yaml +++ b/benchmarl/conf/task/vmas/reverse_transport.yaml @@ -1,6 +1,7 @@ defaults: - - _self_ - vmas_reverse_transport_config + - _self_ + max_steps: 100 n_agents: 4 diff --git a/benchmarl/conf/task/vmas/sampling.yaml b/benchmarl/conf/task/vmas/sampling.yaml index 3d29f503..4ed11f9b 100644 --- a/benchmarl/conf/task/vmas/sampling.yaml +++ b/benchmarl/conf/task/vmas/sampling.yaml @@ -1,6 +1,7 @@ defaults: - - _self_ - vmas_sampling_config + - _self_ + max_steps: 100 diff --git a/benchmarl/conf/task/vmas/simple_adversary.yaml b/benchmarl/conf/task/vmas/simple_adversary.yaml index a78f9467..76e0108e 100644 --- a/benchmarl/conf/task/vmas/simple_adversary.yaml +++ b/benchmarl/conf/task/vmas/simple_adversary.yaml @@ -1,6 +1,7 @@ defaults: - - _self_ - vmas_simple_adversary_config + - _self_ + max_steps: 100 n_agents: 3 diff --git a/benchmarl/conf/task/vmas/simple_crypto.yaml b/benchmarl/conf/task/vmas/simple_crypto.yaml index 4019dd41..0abe6ddc 100644 --- a/benchmarl/conf/task/vmas/simple_crypto.yaml +++ b/benchmarl/conf/task/vmas/simple_crypto.yaml @@ -1,6 +1,7 @@ defaults: - - _self_ - vmas_simple_crypto_config + - _self_ + max_steps: 100 dim_c: 4 diff --git a/benchmarl/conf/task/vmas/simple_push.yaml b/benchmarl/conf/task/vmas/simple_push.yaml index f0945f0f..b7c8bf48 100644 --- a/benchmarl/conf/task/vmas/simple_push.yaml +++ b/benchmarl/conf/task/vmas/simple_push.yaml @@ -1,5 +1,6 @@ defaults: - - _self_ - vmas_simple_push_config + - _self_ + max_steps: 100 diff --git a/benchmarl/conf/task/vmas/simple_reference.yaml b/benchmarl/conf/task/vmas/simple_reference.yaml index 459a0c61..52cc9e9b 100644 --- a/benchmarl/conf/task/vmas/simple_reference.yaml +++ b/benchmarl/conf/task/vmas/simple_reference.yaml @@ -1,5 +1,6 @@ defaults: - - _self_ - vmas_simple_reference_config + - _self_ + max_steps: 100 diff --git a/benchmarl/conf/task/vmas/simple_speaker_listener.yaml b/benchmarl/conf/task/vmas/simple_speaker_listener.yaml index 9d25fe89..1379702a 100644 --- a/benchmarl/conf/task/vmas/simple_speaker_listener.yaml +++ b/benchmarl/conf/task/vmas/simple_speaker_listener.yaml @@ -1,5 +1,6 @@ defaults: - - _self_ - vmas_simple_speaker_listener_config + - _self_ + max_steps: 100 diff --git a/benchmarl/conf/task/vmas/simple_spread.yaml b/benchmarl/conf/task/vmas/simple_spread.yaml index 1411f1e8..6cc67162 100644 --- a/benchmarl/conf/task/vmas/simple_spread.yaml +++ b/benchmarl/conf/task/vmas/simple_spread.yaml @@ -1,6 +1,7 @@ defaults: - - _self_ - vmas_simple_spread_config + - _self_ + max_steps: 100 n_agents: 3 diff --git a/benchmarl/conf/task/vmas/simple_tag.yaml b/benchmarl/conf/task/vmas/simple_tag.yaml index 59022523..de597266 100644 --- a/benchmarl/conf/task/vmas/simple_tag.yaml +++ b/benchmarl/conf/task/vmas/simple_tag.yaml @@ -1,6 +1,7 @@ defaults: - - _self_ - vmas_simple_tag_config + - _self_ + max_steps: 100 num_good_agents: 1 diff --git a/benchmarl/conf/task/vmas/simple_world_comm.yaml b/benchmarl/conf/task/vmas/simple_world_comm.yaml index 30d40d34..736fe835 100644 --- a/benchmarl/conf/task/vmas/simple_world_comm.yaml +++ b/benchmarl/conf/task/vmas/simple_world_comm.yaml @@ -1,6 +1,7 @@ defaults: - - _self_ - vmas_simple_world_comm_config + - _self_ + max_steps: 100 num_good_agents: 2 diff --git a/benchmarl/conf/task/vmas/transport.yaml b/benchmarl/conf/task/vmas/transport.yaml index 7013792b..1431b8e2 100644 --- a/benchmarl/conf/task/vmas/transport.yaml +++ b/benchmarl/conf/task/vmas/transport.yaml @@ -1,6 +1,7 @@ defaults: - - _self_ - vmas_transport_config + - _self_ + max_steps: 100 diff --git a/benchmarl/conf/task/vmas/wheel.yaml b/benchmarl/conf/task/vmas/wheel.yaml index a3aa706c..6da697cb 100644 --- a/benchmarl/conf/task/vmas/wheel.yaml +++ b/benchmarl/conf/task/vmas/wheel.yaml @@ -1,6 +1,7 @@ defaults: - - _self_ - vmas_wheel_config + - _self_ + max_steps: 100 n_agents: 4 diff --git a/benchmarl/conf/task/vmas/wind_flocking.yaml b/benchmarl/conf/task/vmas/wind_flocking.yaml index 6d2a0ff8..c036d769 100644 --- a/benchmarl/conf/task/vmas/wind_flocking.yaml +++ b/benchmarl/conf/task/vmas/wind_flocking.yaml @@ -1,6 +1,7 @@ defaults: - - _self_ - vmas_wind_flocking_config + - _self_ + max_steps: 100 horizon: 100 diff --git a/benchmarl/environments/__init__.py b/benchmarl/environments/__init__.py index 3fa78b0a..4648cc0b 100644 --- a/benchmarl/environments/__init__.py +++ b/benchmarl/environments/__init__.py @@ -4,103 +4,38 @@ # LICENSE file in the root directory of this source tree. # -from .common import Task +from .common import _get_task_config_class, Task from .meltingpot.common import MeltingPotTask from .pettingzoo.common import PettingZooTask from .smacv2.common import Smacv2Task from .vmas.common import VmasTask +# The enum classes for the environments available. +# This is the only object in this file you need to modify when adding a new environment. +tasks = [VmasTask, Smacv2Task, PettingZooTask, MeltingPotTask] + # This is a registry mapping "envname/task_name" to the EnvNameTask.TASK_NAME enum -# It is used by automatically load task enums from yaml files +# It is used by automatically load task enums from yaml files. +# It is populated automatically, do not modify. task_config_registry = {} -for env in [VmasTask, Smacv2Task, PettingZooTask, MeltingPotTask]: - env_config_registry = { - f"{env.env_name()}/{task.name.lower()}": task for task in env - } - task_config_registry.update(env_config_registry) - -from .pettingzoo.multiwalker import TaskConfig as MultiwalkerConfig -from .pettingzoo.simple_adversary import TaskConfig as SimpleAdversaryConfig -from .pettingzoo.simple_crypto import TaskConfig as SimpleCryptoConfig -from .pettingzoo.simple_push import TaskConfig as SimplePushConfig -from .pettingzoo.simple_reference import TaskConfig as SimpleReferenceConfig -from .pettingzoo.simple_speaker_listener import ( - TaskConfig as SimpleSpeakerListenerConfig, -) -from .pettingzoo.simple_spread import TaskConfig as SimpleSpreadConfig -from .pettingzoo.simple_tag import TaskConfig as SimpleTagConfig -from .pettingzoo.simple_world_comm import TaskConfig as SimpleWorldComm -from .pettingzoo.waterworld import TaskConfig as WaterworldConfig +# This is a registry mapping "envname_taskname" to the TaskConfig python dataclass of the task. +# It is used by hydra to validate loaded configs. +# You will see the "envname_taskname" strings in the hydra defaults at the top of yaml files. +# This is optional and, if a task does not possess an associated TaskConfig, this entry will be simply skipped. +# It is populated automatically, do not modify. +_task_class_registry = {} -from .vmas.balance import TaskConfig as BalanceConfig -from .vmas.ball_passage import TaskConfig as BallPassageConfig -from .vmas.ball_trajectory import TaskConfig as BallTrajectoryConfig -from .vmas.buzz_wire import TaskConfig as BuzzWireConfig -from .vmas.discovery import TaskConfig as DiscoveryConfig -from .vmas.dispersion import TaskConfig as DispersionConfig -from .vmas.dropout import TaskConfig as DropoutConfig -from .vmas.flocking import TaskConfig as FlockingConfig -from .vmas.give_way import TaskConfig as GiveWayConfig -from .vmas.joint_passage import TaskConfig as JointPassageConfig -from .vmas.joint_passage_size import TaskConfig as JointPassageSizeConfig -from .vmas.multi_give_way import TaskConfig as MultiGiveWayConfig -from .vmas.navigation import TaskConfig as NavigationConfig -from .vmas.passage import TaskConfig as PassageConfig -from .vmas.reverse_transport import TaskConfig as ReverseTransportConfig -from .vmas.sampling import TaskConfig as SamplingConfig -from .vmas.simple_adversary import TaskConfig as VmasSimpleAdversaryConfig -from .vmas.simple_crypto import TaskConfig as VmasSimpleCryptoConfig -from .vmas.simple_push import TaskConfig as VmasSimplePushConfig -from .vmas.simple_reference import TaskConfig as VmasSimpleReferenceConfig -from .vmas.simple_speaker_listener import TaskConfig as VmasSimpleSpeakerListenerConfig -from .vmas.simple_spread import TaskConfig as VmasSimpleSpreadConfig -from .vmas.simple_tag import TaskConfig as VmasSimpleTagConfig -from .vmas.simple_world_comm import TaskConfig as VmasSimpleWorldComm -from .vmas.transport import TaskConfig as TransportConfig -from .vmas.wheel import TaskConfig as WheelConfig -from .vmas.wind_flocking import TaskConfig as WindFlockingConfig +# Automatic population of registries +for env in tasks: + env_config_registry = {} + environemnt_name = env.env_name() + for task in env: + task_name = task.name.lower() + full_task_name = f"{environemnt_name}/{task_name}" + env_config_registry[full_task_name] = task -# This is a registry mapping task config schemas names to their python dataclass -# It is used by hydra to validate loaded configs. -# You will see the "envname_taskname_config" strings in the hydra defaults at the top of yaml files. -# This feature is optional. -_task_class_registry = { - "vmas_balance_config": BalanceConfig, - "vmas_sampling_config": SamplingConfig, - "vmas_navigation_config": NavigationConfig, - "vmas_transport_config": TransportConfig, - "vmas_reverse_transport_config": ReverseTransportConfig, - "vmas_wheel_config": WheelConfig, - "vmas_dispersion_config": DispersionConfig, - "vmas_give_way_config": GiveWayConfig, - "vmas_multi_give_way_config": MultiGiveWayConfig, - "vmas_passage_config": PassageConfig, - "vmas_joint_passage_config": JointPassageConfig, - "vmas_joint_passage_size_config": JointPassageSizeConfig, - "vmas_ball_passage_config": BallPassageConfig, - "vmas_buzz_wire_config": BuzzWireConfig, - "vmas_ball_trajectory_config": BallTrajectoryConfig, - "vmas_flocking_config": FlockingConfig, - "vmas_wind_flocking_config": WindFlockingConfig, - "vmas_dropout_config": DropoutConfig, - "vmas_discovery_config": DiscoveryConfig, - "vmas_simple_adversary_config": VmasSimpleAdversaryConfig, - "vmas_simple_crypto_config": VmasSimpleCryptoConfig, - "vmas_simple_push_config": VmasSimplePushConfig, - "vmas_simple_reference_config": VmasSimpleReferenceConfig, - "vmas_simple_speaker_listener_config": VmasSimpleSpeakerListenerConfig, - "vmas_simple_spread_config": VmasSimpleSpreadConfig, - "vmas_simple_tag_config": VmasSimpleTagConfig, - "vmas_simple_world_comm_config": VmasSimpleWorldComm, - "pettingzoo_multiwalker_config": MultiwalkerConfig, - "pettingzoo_waterworld_config": WaterworldConfig, - "pettingzoo_simple_adversary_config": SimpleAdversaryConfig, - "pettingzoo_simple_crypto_config": SimpleCryptoConfig, - "pettingzoo_simple_push_config": SimplePushConfig, - "pettingzoo_simple_reference_config": SimpleReferenceConfig, - "pettingzoo_simple_speaker_listener_config": SimpleSpeakerListenerConfig, - "pettingzoo_simple_spread_config": SimpleSpreadConfig, - "pettingzoo_simple_tag_config": SimpleTagConfig, - "pettingzoo_simple_world_comm_config": SimpleWorldComm, -} + task_config_class = _get_task_config_class(environemnt_name, task_name) + if task_config_class is not None: + _task_class_registry[full_task_name.replace("/", "_")] = task_config_class + task_config_registry.update(env_config_registry) diff --git a/benchmarl/environments/common.py b/benchmarl/environments/common.py index 69130dfe..1d5e3994 100644 --- a/benchmarl/environments/common.py +++ b/benchmarl/environments/common.py @@ -7,8 +7,8 @@ from __future__ import annotations import importlib -import os -import os.path as osp + +import warnings from enum import Enum from pathlib import Path from typing import Any, Callable, Dict, List, Optional @@ -20,25 +20,33 @@ from benchmarl.utils import _read_yaml_config, DEVICE_TYPING -def _load_config(name: str, config: Dict[str, Any]): - if not name.endswith(".py"): - name += ".py" +def _type_check_task_config( + environemnt_name: str, + task_name: str, + config: Dict[str, Any], + warn_on_missing_dataclass: bool = True, +): + + task_config_class = _get_task_config_class(environemnt_name, task_name) - pathname = None - for dirpath, _, filenames in os.walk(osp.dirname(__file__)): - if pathname is None: - for filename in filenames: - if filename == name: - pathname = os.path.join(dirpath, filename) - break + if task_config_class is not None: + return task_config_class(**config).__dict__ + else: + if warn_on_missing_dataclass: + warnings.warn( + "TaskConfig python dataclass not foud, task is being loaded without type checks" + ) + return config - if pathname is None: - raise ValueError(f"Task {name} not found.") - spec = importlib.util.spec_from_file_location("", pathname) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - return module.TaskConfig(**config).__dict__ +def _get_task_config_class(environemnt_name: str, task_name: str): + try: + module = importlib.import_module( + f"{'.'.join(__name__.split('.')[:-1])}.{environemnt_name}.{task_name}" + ) + return module.TaskConfig + except ModuleNotFoundError: + return None class Task(Enum): @@ -314,10 +322,12 @@ def get_from_yaml(self, path: Optional[str] = None) -> Task: Returns: the task with the loaded config """ + environment_name = self.env_name() + task_name = self.name.lower() + full_name = str(Path(environment_name) / Path(task_name)) if path is None: - task_name = self.name.lower() - return self.update_config( - Task._load_from_yaml(str(Path(self.env_name()) / Path(task_name))) - ) + config = Task._load_from_yaml(full_name) else: - return self.update_config(**_read_yaml_config(path)) + config = _read_yaml_config(path) + config = _type_check_task_config(environment_name, task_name, config) + return self.update_config(config) diff --git a/benchmarl/hydra_config.py b/benchmarl/hydra_config.py index f59a72fb..f83c25ad 100644 --- a/benchmarl/hydra_config.py +++ b/benchmarl/hydra_config.py @@ -4,9 +4,11 @@ # LICENSE file in the root directory of this source tree. # import importlib +from dataclasses import is_dataclass from benchmarl.algorithms.common import AlgorithmConfig from benchmarl.environments import Task, task_config_registry +from benchmarl.environments.common import _type_check_task_config from benchmarl.experiment import Experiment, ExperimentConfig from benchmarl.models import model_config_registry from benchmarl.models.common import ModelConfig, parse_model_config, SequenceModelConfig @@ -58,9 +60,14 @@ def load_task_config_from_hydra(cfg: DictConfig, task_name: str) -> Task: :class:`~benchmarl.environments.Task` """ - return task_config_registry[task_name].update_config( - OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) - ) + environment_name, inner_task_name = task_name.split("/") + cfg_dict_checked = OmegaConf.to_object(cfg) + if is_dataclass(cfg_dict_checked): + cfg_dict_checked = cfg_dict_checked.__dict__ + cfg_dict_checked = _type_check_task_config( + environment_name, inner_task_name, cfg_dict_checked + ) # Only needed for the warning + return task_config_registry[task_name].update_config(cfg_dict_checked) def load_experiment_config_from_hydra(cfg: DictConfig) -> ExperimentConfig: diff --git a/benchmarl/models/common.py b/benchmarl/models/common.py index 7b332099..892acb78 100644 --- a/benchmarl/models/common.py +++ b/benchmarl/models/common.py @@ -295,8 +295,7 @@ def _load_from_yaml(name: str) -> Dict[str, Any]: / "layers" / f"{name.lower()}.yaml" ) - cfg = _read_yaml_config(str(yaml_path.resolve())) - return parse_model_config(cfg) + return _read_yaml_config(str(yaml_path.resolve())) @classmethod def get_from_yaml(cls, path: Optional[str] = None): @@ -311,13 +310,11 @@ def get_from_yaml(cls, path: Optional[str] = None): Returns: the loaded AlgorithmConfig """ if path is None: - return cls( - **ModelConfig._load_from_yaml( - name=cls.associated_class().__name__, - ) - ) + config = ModelConfig._load_from_yaml(name=cls.associated_class().__name__) else: - return cls(**parse_model_config(_read_yaml_config(path))) + config = _read_yaml_config(path) + config = parse_model_config(config) + return cls(**config) @dataclass diff --git a/examples/extending/algorithm/README.md b/examples/extending/algorithm/README.md index 31d4a8df..d65888b5 100644 --- a/examples/extending/algorithm/README.md +++ b/examples/extending/algorithm/README.md @@ -1,21 +1,20 @@ # Creating a new algorithm -Here are the steps to create a new algorithm. You can find the custom IQL algorithm -created for this example in [`custom_agorithm.py`](custom_algorithm.py). +Here are the steps to create a new algorithm. 1. Create your `CustomAlgorithm` and `CustomAlgorithmConfig` following the example -in [`custom_agorithm.py`](custom_algorithm.py). These will be the algorithm code +in [`algorithms/customalgorithm.py`](algorithms/customalgorithm.py). These will be the algorithm code and an associated dataclass to validate loaded configs. 2. Create a `customalgorithm.yaml` with the configuration parameters you defined in your script. Make sure it has `customalgorithm_config` within its defaults at the top of the file to let hydra know which python dataclass it is -associated to. You can see [`customiqlalgorithm.yaml`](customiqlalgorithm.yaml) +associated to. You can see [`conf/algorithm/customalgorithm.yaml`](conf/algorithm/customalgorithm.yaml) for an example. 3. Place your algorithm script in [`benchmarl/algorithms`](../../../benchmarl/algorithms) and your config in [`benchmarl/conf/algorithm`](../../../benchmarl/conf/algorithm) (or any other place you want to override from) -4. Add `{"customagorithm": CustomAlgorithmConfig}` to the [`benchmarl.algorithms.algorithm_config_registry`](../../../benchmarl/algorithms/__init__.py) +4. Add `{"customalgorithm": CustomAlgorithmConfig}` to the [`benchmarl.algorithms.algorithm_config_registry`](../../../benchmarl/algorithms/__init__.py) 5. Load it with ```bash python benchmarl/run.py algorithm=customalgorithm task=... diff --git a/examples/extending/algorithm/custom_algorithm.py b/examples/extending/algorithm/algorithms/customalgorithm.py similarity index 98% rename from examples/extending/algorithm/custom_algorithm.py rename to examples/extending/algorithm/algorithms/customalgorithm.py index e35da86b..2afcf030 100644 --- a/examples/extending/algorithm/custom_algorithm.py +++ b/examples/extending/algorithm/algorithms/customalgorithm.py @@ -17,7 +17,7 @@ from torchrl.objectives import DQNLoss, LossModule, ValueEstimators -class CustomIqlAlgorithm(Algorithm): +class CustomAlgorithm(Algorithm): def __init__( self, delay_value: bool, loss_function: str, my_custom_arg: int, **kwargs ): @@ -213,7 +213,7 @@ def my_custom_method(self): @dataclass -class CustomIqlConfig(AlgorithmConfig): +class CustomAlgorithmConfig(AlgorithmConfig): # This is a class representing the configuration of your algorithm # It will be used to validate loaded configs, so that everytime you load this algorithm # we know exactly which and what parameters to expect with their types @@ -226,7 +226,7 @@ class CustomIqlConfig(AlgorithmConfig): @staticmethod def associated_class() -> Type[Algorithm]: # The associated algorithm class - return CustomIqlAlgorithm + return CustomAlgorithm @staticmethod def supports_continuous_actions() -> bool: diff --git a/examples/extending/algorithm/customiqlalgorithm.yaml b/examples/extending/algorithm/conf/algorithm/customalgorithm.yaml similarity index 100% rename from examples/extending/algorithm/customiqlalgorithm.yaml rename to examples/extending/algorithm/conf/algorithm/customalgorithm.yaml diff --git a/examples/extending/model/README.md b/examples/extending/model/README.md index c087aa29..79db9b71 100644 --- a/examples/extending/model/README.md +++ b/examples/extending/model/README.md @@ -4,16 +4,16 @@ Here are the steps to create a new model. 1. Create your `CustomModel` and `CustomModelConfig` following the example -in [`custom_model.py`](custom_model.py). These will be the model code +in [`models/custommodel.py`](models/custommodel.py). These will be the model code and an associated dataclass to validate loaded configs. 2. Create a `custommodel.yaml` with the configuration parameters you defined -in your script. Make sure it has a `name` entry equal to `custom_model` to let hydra know which python dataclass it is -associated to. You can see [`custommodel.yaml`](custommodel.yaml) +in your script. Make sure it has a `name` entry equal to `custommodel` to let hydra know which python dataclass it is +associated to. You can see [`conf/model/layers/custommodel.yaml`](conf/model/layers/custommodel.yaml) for an example. 3. Place your model script in [`benchmarl/models`](../../../benchmarl/models) and your config in [`benchmarl/conf/model/layers`](../../../benchmarl/conf/model/layers) (or any other place you want to override from) -4. Add `{"custom_model": CustomModelConfig}` to the [`benchmarl.models.model_config_registry`](../../../benchmarl/models/__init__.py) +4. Add `{"custommodel": CustomModelConfig}` to the [`benchmarl.models.model_config_registry`](../../../benchmarl/models/__init__.py) 5. Load it with ```bash python benchmarl/run.py model=layers/custommodel algorithm=... task=... diff --git a/examples/extending/model/custommodel.yaml b/examples/extending/model/conf/model/layers/custommodel.yaml similarity index 72% rename from examples/extending/model/custommodel.yaml rename to examples/extending/model/conf/model/layers/custommodel.yaml index 547ece39..90dda778 100644 --- a/examples/extending/model/custommodel.yaml +++ b/examples/extending/model/conf/model/layers/custommodel.yaml @@ -1,4 +1,4 @@ -name: custom_model +name: custommodel custom_param: 3 activation_class: torch.nn.Tanh diff --git a/examples/extending/model/custom_model.py b/examples/extending/model/models/custommodel.py similarity index 100% rename from examples/extending/model/custom_model.py rename to examples/extending/model/models/custommodel.py diff --git a/examples/extending/task/README.md b/examples/extending/task/README.md index bd596065..9c0d553d 100644 --- a/examples/extending/task/README.md +++ b/examples/extending/task/README.md @@ -1,23 +1,43 @@ +# Creating a new task -# Creating a new task (from a new environment) +In the following we will see how to: +- Create new tasks from a new environment +- Create new tasks from an existing environment -Here are the steps to create a new task. +## Creating new tasks from a new environment -1. Create your `CustomEnvTask` following the example in [`custom_task.py`](custom_task.py). -This is an enum with task entries and some abstract functions you need to implement. +Here are the steps to create a new task and a new environment. -2. Create a `customenv` folder with a yaml configuration file for each of your tasks. -You can see [`customenv`](customenv) for an example. +1. Create your `CustomEnvTask` following the example in [`environments/customenv/common.py`](environments/customenv/common.py). +This is an enum with task entries and some abstract functions you need to implement. The entries of this enum will be the +uppercase names of your tasks. +2. Create a `conf/task/customenv` folder with a yaml configuration file for each of your tasks. This folder will have a +yaml configuration for each task. You can see [`conf/task/customenv`](conf/task/customenv) for an example. 3. Place your task script in [`benchmarl/environments/customenv/common.py`](../../../benchmarl/environments) and -your config in [`benchmarl/conf/task`](../../../benchmarl/conf/task) (or any other place you want to +your config in [`benchmarl/conf/task/customenv`](../../../benchmarl/conf/task) (or any other place you want to override from). -4. Add `{"customenv/{task_name}": CustomEnvTask.TASK_NAME}` to the -[`benchmarl.environments.task_config_registry`](../../../benchmarl/environments/__init__.py) for all tasks. +4. Add `CustomEnvTask` to [`benchmarl.environments.tasks`](../../../benchmarl/environments/__init__.py) list. 5. Load it with ```bash -python benchmarl/run.py task=customenv/task_name algorithm=... +python benchmarl/run.py task=customenv/task_1 algorithm=... ``` 6. (Optional) You can create python dataclasses to use as schemas for your tasks -to validate their config. We are not going to illustrate this here, but if -you want to see an example, check out [`benchmarl/environments/vmas`](../../../benchmarl/environments/vmas). +to validate their config. This will allow to check the configuration entries and types for each task. +This step is optional and, if you skip it, everything will work (just without the task config being checked against python dataclasses). +To do it, just create `environments/customenv/taskname.py` for each task, with a `TaskConfig` object following the structure shown in +[`environments/customenv/task_1.py`](environments/customenv/task_1.py). In our example, `task_1` has such dataclass, while `task_2` +doesn't. The name of the python file has to be the name of the task in lower case. Then you need to tell hydra to use +this as a schema by adding `customenv_taskname_config` to the defaults at the top of the task yaml file. +See [`conf/task/customenv/task_1.yaml`](conf/task/customenv/task_1.yaml) for an example. + +## Creating new tasks from an existing environment + +Imagine we now have already in the library `customenv` with `task_1` and `task_2`. +To create a new task (e.g., `task_3`) in an existing environment , follow these steps: + +1. Add `TASK_3 = None` to `CustomEnvTask` in [`environments/customenv/common.py`](environments/customenv/common.py). +2. Add `task_3.yaml` to [`conf/task/customenv`](conf/task/customenv) + +3. (Optional) Add `task_3.py` to [`environments/customenv`](environments/customenv) and +the default `customenv_task_3_config` at the top of `task_3.yaml`. diff --git a/examples/extending/task/conf/task/customenv/task_1.yaml b/examples/extending/task/conf/task/customenv/task_1.yaml new file mode 100644 index 00000000..6d27b326 --- /dev/null +++ b/examples/extending/task/conf/task/customenv/task_1.yaml @@ -0,0 +1,6 @@ +defaults: + - customenv_task_1_config + - _self_ + +n_borks: 3 +win_on_dork: True diff --git a/examples/extending/task/customenv/task_2.yaml b/examples/extending/task/conf/task/customenv/task_2.yaml similarity index 100% rename from examples/extending/task/customenv/task_2.yaml rename to examples/extending/task/conf/task/customenv/task_2.yaml diff --git a/examples/extending/task/customenv/task_1.yaml b/examples/extending/task/customenv/task_1.yaml deleted file mode 100644 index f87c0fde..00000000 --- a/examples/extending/task/customenv/task_1.yaml +++ /dev/null @@ -1,3 +0,0 @@ - -n_borks: 3 -win_on_dork: True diff --git a/examples/extending/task/custom_task.py b/examples/extending/task/environments/customenv/common.py similarity index 93% rename from examples/extending/task/custom_task.py rename to examples/extending/task/environments/customenv/common.py index 336b3c41..cd2c3229 100644 --- a/examples/extending/task/custom_task.py +++ b/examples/extending/task/environments/customenv/common.py @@ -17,10 +17,10 @@ class CustomEnvTask(Task): # Your task names. - # Their config will be loaded from benchmarl/conf/task/customenv + # Their config will be loaded from conf/task/customenv - TASK_1 = None # Loaded automatically from benchmarl/conf/task/customenv/task_1 - TASK_2 = None # Loaded automatically from benchmarl/conf/task/customenv/task_2 + TASK_1 = None # Loaded automatically from conf/task/customenv/task_1 + TASK_2 = None # Loaded automatically from conf/task/customenv/task_2 def get_env_fun( self, diff --git a/examples/extending/task/environments/customenv/task_1.py b/examples/extending/task/environments/customenv/task_1.py new file mode 100644 index 00000000..0d0277f0 --- /dev/null +++ b/examples/extending/task/environments/customenv/task_1.py @@ -0,0 +1,13 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +from dataclasses import dataclass, MISSING + + +@dataclass +class TaskConfig: + n_borks: int = MISSING + win_on_dork: bool = MISSING diff --git a/test/test_task.py b/test/test_task.py index 60890024..fc1660f6 100644 --- a/test/test_task.py +++ b/test/test_task.py @@ -4,15 +4,24 @@ # LICENSE file in the root directory of this source tree. # -import pytest +import contextlib -from benchmarl.environments import Task, task_config_registry +import pytest +from benchmarl.environments import _task_class_registry, Task, task_config_registry from benchmarl.hydra_config import load_task_config_from_hydra from hydra import compose, initialize @pytest.mark.parametrize("task_name", task_config_registry.keys()) def test_loading_tasks(task_name): + task_dataclasses_names = list(_task_class_registry.keys()) + config_task_name = task_name.replace("/", "_") + task_has_dataclass = False + for task_dataclass_name in task_dataclasses_names: + if config_task_name in task_dataclass_name: + task_has_dataclass = True + break + with initialize(version_base=None, config_path="../benchmarl/conf"): cfg = compose( config_name="config", @@ -23,5 +32,24 @@ def test_loading_tasks(task_name): return_hydra_config=True, ) task_name_hydra = cfg.hydra.runtime.choices.task - task: Task = load_task_config_from_hydra(cfg.task, task_name=task_name_hydra) - assert task == task_config_registry[task_name].get_from_yaml() + assert task_name_hydra == task_name + + warn_message = "TaskConfig python dataclass not foud, task is being loaded without type checks" + + with ( + pytest.warns(match=warn_message) + if not task_has_dataclass + else contextlib.nullcontext() + ): + task: Task = load_task_config_from_hydra( + cfg.task, task_name=task_name_hydra + ) + + with ( + pytest.warns(match=warn_message) + if not task_has_dataclass + else contextlib.nullcontext() + ): + task_from_yaml: Task = task_config_registry[task_name].get_from_yaml() + + assert task == task_from_yaml