Skip to content

Commit

Permalink
ensemble algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Jan 7, 2025
1 parent 441d897 commit 8113962
Show file tree
Hide file tree
Showing 4 changed files with 211 additions and 0 deletions.
1 change: 1 addition & 0 deletions benchmarl/algorithms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#

from .common import Algorithm, AlgorithmConfig
from .ensemble import EnsembleAlgorithm, EnsembleAlgorithmConfig
from .iddpg import Iddpg, IddpgConfig
from .ippo import Ippo, IppoConfig
from .iql import Iql, IqlConfig
Expand Down
128 changes: 128 additions & 0 deletions benchmarl/algorithms/ensemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#


from dataclasses import dataclass
from typing import Dict, Iterable, Optional, Tuple, Type

from tensordict import TensorDictBase
from tensordict.nn import TensorDictModule

from torchrl.objectives import LossModule

from benchmarl.algorithms.common import Algorithm, AlgorithmConfig

from benchmarl.models.common import ModelConfig


class EnsembleAlgorithm(Algorithm):
def __init__(self, algorithms_map, **kwargs):
super().__init__(**kwargs)
self.algorithms_map = algorithms_map

def _get_loss(
self, group: str, policy_for_loss: TensorDictModule, continuous: bool
) -> Tuple[LossModule, bool]:
return self.algorithms_map[group]._get_loss(group, policy_for_loss, continuous)

def _get_parameters(self, group: str, loss: LossModule) -> Dict[str, Iterable]:
return self.algorithms_map[group]._get_parameters(group, loss)

def _get_policy_for_loss(
self, group: str, model_config: ModelConfig, continuous: bool
) -> TensorDictModule:
return self.algorithms_map[group]._get_policy_for_loss(
group, model_config, continuous
)

def _get_policy_for_collection(
self, policy_for_loss: TensorDictModule, group: str, continuous: bool
) -> TensorDictModule:
return self.algorithms_map[group]._get_policy_for_collection(
policy_for_loss, group, continuous
)

def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
return self.algorithms_map[group].process_batch(group, batch)

def process_loss_vals(
self, group: str, loss_vals: TensorDictBase
) -> TensorDictBase:
return self.algorithms_map[group].process_loss_vals(group, loss_vals)


@dataclass
class EnsembleAlgorithmConfig(AlgorithmConfig):

algorithm_configs_map: Dict[str, AlgorithmConfig]

def __post_init__(self):
algorithm_configs = list(self.algorithm_configs_map.values())
self._on_policy = algorithm_configs[0].on_policy()

for algorithm_config in algorithm_configs[1:]:
if algorithm_config.on_policy() != self._on_policy:
raise ValueError(
"Algorithms in EnsembleAlgorithmConfig must either be all on_policy or all off_policy"
)

if (
not self.supports_discrete_actions()
and not self.supports_continuous_actions()
):
raise ValueError(
"Ensemble algorithm does not support discrete actions nor continuous actions."
" Make sure that at least one type of action is supported across all the algorithms used."
)

def get_algorithm(self, experiment) -> Algorithm:
return self.associated_class()(
algorithms_map={
group: algorithm_config.get_algorithm(experiment)
for group, algorithm_config in self.algorithm_configs_map.items()
},
experiment=experiment,
)

@classmethod
def get_from_yaml(cls, path: Optional[str] = None):
raise NotImplementedError

@staticmethod
def associated_class() -> Type[Algorithm]:
return EnsembleAlgorithm

def on_policy(self) -> bool:
return self._on_policy

def supports_continuous_actions(self) -> bool:
supports_continuous_actions = True
for algorithm_config in self.algorithm_configs_map.values():
supports_continuous_actions *= (
algorithm_config.supports_continuous_actions()
)
return supports_continuous_actions

def supports_discrete_actions(self) -> bool:
supports_discrete_actions = True
for algorithm_config in self.algorithm_configs_map.values():
supports_discrete_actions *= algorithm_config.supports_discrete_actions()
return supports_discrete_actions

def has_independent_critic(self) -> bool:
has_independent_critic = False
for algorithm_config in self.algorithm_configs_map.values():
has_independent_critic += algorithm_config.has_independent_critic()
return has_independent_critic

def has_centralized_critic(self) -> bool:
has_centralized_critic = False
for algorithm_config in self.algorithm_configs_map.values():
has_centralized_critic += algorithm_config.has_centralized_critic()
return has_centralized_critic

def has_critic(self) -> bool:
return self.has_centralized_critic() or self.has_independent_critic()
38 changes: 38 additions & 0 deletions examples/ensemble/ensemble_algorithm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#


from benchmarl.algorithms import EnsembleAlgorithmConfig, IsacConfig, MaddpgConfig
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models import MlpConfig


if __name__ == "__main__":

# Loads from "benchmarl/conf/experiment/base_experiment.yaml"
experiment_config = ExperimentConfig.get_from_yaml()

# Loads from "benchmarl/conf/task/vmas/simple_tag.yaml"
task = VmasTask.SIMPLE_TAG.get_from_yaml()

# Loads from "benchmarl/conf/model/layers/mlp.yaml"
model_config = MlpConfig.get_from_yaml()
critic_model_config = MlpConfig.get_from_yaml()

algorithm_config = EnsembleAlgorithmConfig(
{"agent": MaddpgConfig.get_from_yaml(), "adversary": IsacConfig.get_from_yaml()}
)

experiment = Experiment(
task=task,
algorithm_config=algorithm_config,
model_config=model_config,
critic_model_config=critic_model_config,
seed=0,
config=experiment_config,
)
experiment.run()
44 changes: 44 additions & 0 deletions examples/ensemble/ensemble_algorithm_and_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#


from benchmarl.algorithms import EnsembleAlgorithmConfig, IppoConfig, MappoConfig
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models import MlpConfig
from models import DeepsetsConfig, EnsembleModelConfig, GnnConfig

if __name__ == "__main__":

# Loads from "benchmarl/conf/experiment/base_experiment.yaml"
experiment_config = ExperimentConfig.get_from_yaml()

# Loads from "benchmarl/conf/task/vmas/simple_tag.yaml"
task = VmasTask.SIMPLE_TAG.get_from_yaml()

algorithm_config = EnsembleAlgorithmConfig(
{"agent": MappoConfig.get_from_yaml(), "adversary": IppoConfig.get_from_yaml()}
)

model_config = EnsembleModelConfig(
{"agent": MlpConfig.get_from_yaml(), "adversary": GnnConfig.get_from_yaml()}
)
critic_model_config = EnsembleModelConfig(
{
"agent": DeepsetsConfig.get_from_yaml(),
"adversary": MlpConfig.get_from_yaml(),
}
)

experiment = Experiment(
task=task,
algorithm_config=algorithm_config,
model_config=model_config,
critic_model_config=critic_model_config,
seed=0,
config=experiment_config,
)
experiment.run()

0 comments on commit 8113962

Please sign in to comment.