diff --git a/README.md b/README.md index d6a5961c..c5ee12ce 100644 --- a/README.md +++ b/README.md @@ -260,11 +260,12 @@ when requested, as critics. We provide a set of base models (layers) and a Seque different layers. All the models can be used with or without parameter sharing within an agent group. Here is a table of the models implemented in BenchMARL -| Name | Decentralized | Centralized with local inputs | Centralized with global input | -|--------------------------------|:-------------:|:-----------------------------:|:-----------------------------:| -| [MLP](benchmarl/models/mlp.py) | Yes | Yes | Yes | -| [GNN](benchmarl/models/gnn.py) | Yes | Yes | No | -| [CNN](benchmarl/models/cnn.py) | Yes | Yes | Yes | +| Name | Decentralized | Centralized with local inputs | Centralized with global input | +|------------------------------------------|:-------------:|:-----------------------------:|:-----------------------------:| +| [MLP](benchmarl/models/mlp.py) | Yes | Yes | Yes | +| [GNN](benchmarl/models/gnn.py) | Yes | Yes | No | +| [CNN](benchmarl/models/cnn.py) | Yes | Yes | Yes | +| [Deepsets](benchmarl/models/deepsets.py) | Yes | Yes | Yes | And the ones that are _work in progress_ diff --git a/benchmarl/conf/model/layers/deepsets.yaml b/benchmarl/conf/model/layers/deepsets.yaml new file mode 100644 index 00000000..103546e3 --- /dev/null +++ b/benchmarl/conf/model/layers/deepsets.yaml @@ -0,0 +1,9 @@ + +name: deepsets + +aggr: "sum" +local_nn_num_cells: [128, 128] +local_nn_activation_class: torch.nn.Tanh +out_features_local_nn: 256 +global_nn_num_cells: [256, 256] +global_nn_activation_class: torch.nn.Tanh diff --git a/benchmarl/models/__init__.py b/benchmarl/models/__init__.py index 33510c48..8bd743be 100644 --- a/benchmarl/models/__init__.py +++ b/benchmarl/models/__init__.py @@ -6,9 +6,24 @@ from .cnn import Cnn, CnnConfig from .common import Model, ModelConfig, SequenceModel, SequenceModelConfig +from .deepsets import Deepsets, DeepsetsConfig from .gnn import Gnn, GnnConfig from .mlp import Mlp, MlpConfig -classes = ["Mlp", "MlpConfig", "Gnn", "GnnConfig", "Cnn", "CnnConfig"] +classes = [ + "Mlp", + "MlpConfig", + "Gnn", + "GnnConfig", + "Cnn", + "CnnConfig", + "Deepsets", + "DeepsetsConfig", +] -model_config_registry = {"mlp": MlpConfig, "gnn": GnnConfig, "cnn": CnnConfig} +model_config_registry = { + "mlp": MlpConfig, + "gnn": GnnConfig, + "cnn": CnnConfig, + "deepsets": DeepsetsConfig, +} diff --git a/benchmarl/models/deepsets.py b/benchmarl/models/deepsets.py new file mode 100644 index 00000000..c1513093 --- /dev/null +++ b/benchmarl/models/deepsets.py @@ -0,0 +1,374 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +from __future__ import annotations + +from dataclasses import dataclass, MISSING +from typing import Optional, Sequence, Type + +import torch +from tensordict import TensorDictBase +from torch import nn, Tensor +from torchrl.modules import MLP + +from benchmarl.models.common import Model, ModelConfig + + +class Deepsets(Model): + r"""Deepsets Model from `this paper `__ . + + The BenchMARL Deepsets accepts multiple inputs of 2 types: + + - sets :math:`s` : Tensors of shape ``(*batch,S,F)`` + - arrays :math:`x` : Tensors of shape ``(*batch,F)`` + + The Deepsets model will check that all set inputs have the same shape (excluding the last dimension) + and cat them along that dimension before processing them. + + It will check that all array inputs have the same shape (excluding the last dimension) + and cat them along that dimension. + + It will then compute the output according to the following function. + + .. math:: + + \rho \left (x, \bigoplus_{s\in S}\phi(s) \right ), + + where :math:`\rho,\phi` are MLPs configurable in the model setup. + + The model is useful in various contexts, for example: + + - When used as a policy (``self.centralized==False``, ``self.input_has_agent_dim==True``), it can process + observations with shape ``(*batch,n_agents,S,F)``, reducing them to ``(*batch,n_agents,F)`` + - When used a a centralized crtic with a global state as input + (``self.centralized==True``, ``self.input_has_agent_dim==False``), it can process the global state with shape + ``(*batch,S,F)`` , reducing it to ``(*batch,F)``. + - When used a a centralized crtic with local agent observations as input + (``self.centralized==True``, ``self.input_has_agent_dim==True``), it can process normal agent observations with shape + ``(*batch,n_agents,F)``, reducing them to ``(*batch,F)``. **Note**: If the agents also have set observations + ``(*batch,n_agents,S,F)`` it will apply two deep sets networks. The first will remove the set dimension + in the agents' inputs (``(*batch,n_agents,F)``), and the second will remove the agent dimension (``(*batch,F)``). + Both networks will share the same configuration. + + Args: + aggr (str): The aggregation strategy to use in the Deepsets model. + local_nn_num_cells (Sequence[int]): number of cells of every layer in between the input and output in the :math:`\phi` MLP. + local_nn_activation_class (Type[nn.Module]): activation class to be used in the :math:`\phi` MLP. + out_features_local_nn (int): output features of the :math:`\phi` MLP. + global_nn_num_cells (Sequence[int]): number of cells of every layer in between the input and output in the :math:`\rho` MLP. + global_nn_activation_class (Type[nn.Module]): activation class to be used in the :math:`\rho` MLP. + + + """ + + def __init__( + self, + aggr: str, + local_nn_num_cells: Sequence[int], + local_nn_activation_class: Type[nn.Module], + out_features_local_nn: int, + global_nn_num_cells: Sequence[int], + global_nn_activation_class: Type[nn.Module], + **kwargs, + ): + + super().__init__(**kwargs) + self.aggr = aggr + self.local_nn_num_cells = local_nn_num_cells + self.local_nn_activation_class = local_nn_activation_class + self.global_nn_num_cells = global_nn_num_cells + self.global_nn_activation_class = global_nn_activation_class + self.out_features_local_nn = out_features_local_nn + + self.input_local_set_features = sum( + [self.input_spec[key].shape[-1] for key in self.set_in_keys_local] + ) + self.input_local_tensor_features = sum( + [self.input_spec[key].shape[-1] for key in self.tensor_in_keys_local] + ) + self.input_global_set_features = sum( + [self.input_spec[key].shape[-1] for key in self.set_in_keys_global] + ) + self.input_global_tensor_features = sum( + [self.input_spec[key].shape[-1] for key in self.tensor_in_keys_global] + ) + + self.output_features = self.output_leaf_spec.shape[-1] + + if self.input_local_set_features > 0: # Need local deepsets + self.local_deepsets = nn.ModuleList( + [ + self._make_deepsets_net( + in_features=self.input_local_set_features, + out_features_local_nn=self.out_features_local_nn, + in_fetures_global_nn=self.out_features_local_nn + + self.input_local_tensor_features, + out_features=( + self.output_features + if not self.centralised + else self.out_features_local_nn + ), + aggr=self.aggr, + local_nn_activation_class=self.local_nn_activation_class, + global_nn_activation_class=self.global_nn_activation_class, + local_nn_num_cells=self.local_nn_num_cells, + global_nn_num_cells=self.global_nn_num_cells, + ) + for _ in range(self.n_agents if not self.share_params else 1) + ] + ) + if self.centralised: # Need global deepsets + self.global_deepsets = nn.ModuleList( + [ + self._make_deepsets_net( + in_features=( + self.input_global_set_features + if self.input_local_set_features == 0 + else self.out_features_local_nn + ), + out_features_local_nn=self.out_features_local_nn, + in_fetures_global_nn=self.out_features_local_nn + + self.input_global_tensor_features, + out_features=self.output_features, + aggr=self.aggr, + local_nn_activation_class=self.local_nn_activation_class, + global_nn_activation_class=self.global_nn_activation_class, + local_nn_num_cells=self.local_nn_num_cells, + global_nn_num_cells=self.global_nn_num_cells, + ) + for _ in range(self.n_agents if not self.share_params else 1) + ] + ) + + def _make_deepsets_net( + self, + in_features: int, + out_features: int, + aggr: str, + local_nn_num_cells: Sequence[int], + local_nn_activation_class: Type[nn.Module], + global_nn_num_cells: Sequence[int], + global_nn_activation_class: Type[nn.Module], + out_features_local_nn: int, + in_fetures_global_nn: int, + ) -> _DeepsetsNet: + local_nn = MLP( + in_features=in_features, + out_features=out_features_local_nn, + num_cells=local_nn_num_cells, + activation_class=local_nn_activation_class, + device=self.device, + ) + global_nn = MLP( + in_features=in_fetures_global_nn, + out_features=out_features, + num_cells=global_nn_num_cells, + activation_class=global_nn_activation_class, + device=self.device, + ) + return _DeepsetsNet(local_nn, global_nn, aggr=aggr) + + def _perform_checks(self): + super()._perform_checks() + + input_shape_tensor_local = None + self.tensor_in_keys_local = [] + input_shape_set_local = None + self.set_in_keys_local = [] + + input_shape_tensor_global = None + self.tensor_in_keys_global = [] + input_shape_set_global = None + self.set_in_keys_global = [] + + error_invalid_input = ValueError( + f"DeepSet set inputs should all have the same shape up to the last dimension, got {self.input_spec}" + ) + + for input_key, input_spec in self.input_spec.items(True, True): + if self.input_has_agent_dim and len(input_spec.shape) == 3: + self.set_in_keys_local.append(input_key) + if input_shape_set_local is None: + input_shape_set_local = input_spec.shape[:-1] + elif input_spec.shape[:-1] != input_shape_set_local: + raise error_invalid_input + elif self.input_has_agent_dim and len(input_spec.shape) == 2: + self.tensor_in_keys_local.append(input_key) + if input_shape_tensor_local is None: + input_shape_tensor_local = input_spec.shape[:-1] + elif input_spec.shape[:-1] != input_shape_tensor_local: + raise error_invalid_input + elif not self.input_has_agent_dim and len(input_spec.shape) == 2: + self.set_in_keys_global.append(input_key) + if input_shape_set_global is None: + input_shape_set_global = input_spec.shape[:-1] + elif input_spec.shape[:-1] != input_shape_set_global: + raise error_invalid_input + elif not self.input_has_agent_dim and len(input_spec.shape) == 1: + self.tensor_in_keys_global.append(input_key) + if input_shape_tensor_global is None: + input_shape_tensor_global = input_spec.shape[:-1] + elif input_spec.shape[:-1] != input_shape_tensor_global: + raise error_invalid_input + else: + raise ValueError( + f"DeepSets input value {input_key} from {self.input_spec} has an invalid shape" + ) + + # Centralized model not needing any local deepsets + if ( + self.centralised + and not len(self.set_in_keys_local) + and self.input_has_agent_dim + ): + self.set_in_keys_global = self.tensor_in_keys_local + input_shape_set_global = input_shape_tensor_local + self.tensor_in_keys_local = [] + + if (not self.centralised and not len(self.set_in_keys_local)) or ( + self.centralised + and not self.input_has_agent_dim + and not len(self.set_in_keys_global) + ): + raise ValueError("DeepSets found no set inputs, maybe use an MLP?") + + if len(self.set_in_keys_local) and input_shape_set_local[-2] != self.n_agents: + raise ValueError() + if ( + len(self.tensor_in_keys_local) + and input_shape_tensor_local[-1] != self.n_agents + ): + raise ValueError() + if ( + len(self.set_in_keys_global) + and self.input_has_agent_dim + and input_shape_set_global[-1] != self.n_agents + ): + raise ValueError() + + if ( + self.output_has_agent_dim + and ( + self.output_leaf_spec.shape[-2] != self.n_agents + or len(self.output_leaf_spec.shape) != 2 + ) + ) or (not self.output_has_agent_dim and len(self.output_leaf_spec.shape) != 1): + raise ValueError() + + def _forward(self, tensordict: TensorDictBase) -> TensorDictBase: + if len(self.set_in_keys_local): + # Local deep sets + input_local_sets = torch.cat( + [tensordict.get(in_key) for in_key in self.set_in_keys_local], dim=-1 + ) + input_local_tensors = None + if len(self.tensor_in_keys_local): + input_local_tensors = torch.cat( + [tensordict.get(in_key) for in_key in self.tensor_in_keys_local], + dim=-1, + ) + if self.share_params: + local_output = self.local_deepsets[0]( + input_local_sets, input_local_tensors + ) + else: + local_output = torch.stack( + [ + net(input_local_sets, input_local_tensors)[..., i, :] + for i, net in enumerate(self.local_deepsets) + ], + dim=-2, + ) + else: + local_output = None + + if self.centralised: + if local_output is None: + # gather local output + local_output = torch.cat( + [tensordict.get(in_key) for in_key in self.set_in_keys_global], + dim=-1, + ) + # Global deepsets + input_global_tensors = None + if len(self.tensor_in_keys_global): + input_global_tensors = torch.cat( + [tensordict.get(in_key) for in_key in self.tensor_in_keys_global], + dim=-1, + ) + if self.share_params: + global_output = self.global_deepsets[0]( + local_output, input_global_tensors + ) + else: + global_output = torch.stack( + [ + net(local_output, input_global_tensors) + for i, net in enumerate(self.global_deepsets) + ], + dim=-2, + ) + tensordict.set(self.out_key, global_output) + else: + tensordict.set(self.out_key, local_output) + + return tensordict + + +class _DeepsetsNet(nn.Module): + """https://arxiv.org/abs/1703.06114""" + + def __init__( + self, + local_nn: torch.nn.Module, + global_nn: torch.nn.Module, + set_dim: int = -2, + aggr: str = "sum", + ): + super().__init__() + self.aggr = aggr + self.set_dim = set_dim + self.local_nn = local_nn + self.global_nn = global_nn + + def forward(self, x: Tensor, extra_global_input: Optional[Tensor]) -> Tensor: + x = self.local_nn(x) + x = self.reduce(x, dim=self.set_dim, aggr=self.aggr) + if extra_global_input is not None: + x = torch.cat([x, extra_global_input], dim=-1) + x = self.global_nn(x) + return x + + @staticmethod + def reduce(x: Tensor, dim: int, aggr: str) -> Tensor: + if aggr == "sum" or aggr == "add": + return torch.sum(x, dim=dim) + elif aggr == "mean": + return torch.mean(x, dim=dim) + elif aggr == "max": + return torch.max(x, dim=dim)[0] + elif aggr == "min": + return torch.min(x, dim=dim)[0] + elif aggr == "mul": + return torch.prod(x, dim=dim) + + +@dataclass +class DeepsetsConfig(ModelConfig): + """Dataclass config for a :class:`~benchmarl.models.Deepsets`.""" + + aggr: str = MISSING + out_features_local_nn: int = MISSING + + local_nn_num_cells: Sequence[int] = MISSING + local_nn_activation_class: Type[nn.Module] = MISSING + + global_nn_num_cells: Sequence[int] = MISSING + global_nn_activation_class: Type[nn.Module] = MISSING + + @staticmethod + def associated_class(): + return Deepsets diff --git a/docs/source/concepts/components.rst b/docs/source/concepts/components.rst index 50dfe47d..f8e22871 100644 --- a/docs/source/concepts/components.rst +++ b/docs/source/concepts/components.rst @@ -107,12 +107,14 @@ agent group. Here is a table of the models implemented in BenchMARL .. table:: Models in BenchMARL - +---------------------------------+---------------+-------------------------------+-------------------------------+ - | Name | Decentralized | Centralized with local inputs | Centralized with global input | - +=================================+===============+===============================+===============================+ - | :class:`~benchmarl.models.Mlp` | Yes | Yes | Yes | - +---------------------------------+---------------+-------------------------------+-------------------------------+ - | :class:`~benchmarl.models.Gnn` | Yes | Yes | No | - +---------------------------------+---------------+-------------------------------+-------------------------------+ - | :class:`~benchmarl.models.Cnn` | Yes | Yes | Yes | - +---------------------------------+---------------+-------------------------------+-------------------------------+ + +-------------------------------------+---------------+-------------------------------+-------------------------------+ + | Name | Decentralized | Centralized with local inputs | Centralized with global input | + +=====================================+===============+===============================+===============================+ + | :class:`~benchmarl.models.Mlp` | Yes | Yes | Yes | + +-------------------------------------+---------------+-------------------------------+-------------------------------+ + | :class:`~benchmarl.models.Gnn` | Yes | Yes | No | + +-------------------------------------+---------------+-------------------------------+-------------------------------+ + | :class:`~benchmarl.models.Cnn` | Yes | Yes | Yes | + +-------------------------------------+---------------+-------------------------------+-------------------------------+ + | :class:`~benchmarl.models.Deepsets` | Yes | Yes | Yes | + +-------------------------------------+---------------+-------------------------------+-------------------------------+ diff --git a/docs/source/conf.py b/docs/source/conf.py index 4232fd7c..12b93c25 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -23,6 +23,7 @@ "sphinx.ext.napoleon", "sphinx.ext.intersphinx", "sphinx.ext.viewcode", + "sphinx.ext.mathjax", "patch", ] diff --git a/setup.cfg b/setup.cfg index dd9f5082..b786fccc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,7 +7,7 @@ max-line-length = 120 [flake8] # note: we ignore all 501s (line too long) anyway as they're taken care of by black max-line-length = 79 -ignore = E203, E402, W503, W504, E501 +ignore = E203, E402, W503, W504, E501, W605 per-file-ignores = __init__.py: F401, F403, F405 test_*.py: F841, E731, E266 diff --git a/test/test_models.py b/test/test_models.py index e33455eb..4af11458 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -29,17 +29,21 @@ def _get_input_and_output_specs( out_features=4, x=12, y=12, + set_size=5, ): if model_name == "cnn": multi_agent_input_shape = (n_agents, x, y, in_features) single_agent_input_shape = (x, y, in_features) + elif model_name == "deepsets": + multi_agent_input_shape = (n_agents, set_size, in_features) + single_agent_input_shape = (set_size, in_features) else: multi_agent_input_shape = (n_agents, in_features) single_agent_input_shape = in_features - other_multi_agent_input_shape = (n_agents, in_features) - other_single_agent_input_shape = in_features + other_multi_agent_input_shape = (n_agents, in_features + 1) + other_single_agent_input_shape = in_features + 1 if input_has_agent_dim: input_spec = CompositeSpec( @@ -216,7 +220,7 @@ def test_share_params_between_models( or (isinstance(model_name, list) and model_name[0] != "gnn") ): pytest.skip("gnn model needs agent dim as input") - torch.manual_seed(0) + torch.manual_seed(1) input_spec, output_spec = _get_input_and_output_specs( centralised=centralised, @@ -225,13 +229,6 @@ def test_share_params_between_models( share_params=share_params, n_agents=n_agents, ) - input_spec2, output_spec2 = _get_input_and_output_specs( - centralised=centralised, - input_has_agent_dim=input_has_agent_dim, - model_name=model_name if isinstance(model_name, str) else model_name[0], - share_params=share_params, - n_agents=n_agents, - ) if isinstance(model_name, List): config = SequenceModelConfig( @@ -254,8 +251,8 @@ def test_share_params_between_models( action_spec=None, ) second_model = config.get_model( - input_spec=input_spec2, - output_spec=output_spec2, + input_spec=input_spec, + output_spec=output_spec, share_params=share_params, centralised=centralised, input_has_agent_dim=input_has_agent_dim, @@ -372,3 +369,75 @@ def test_gnn_edge_attrs( obs_input = input_spec.expand(batch_size).rand() output = gnn(obs_input) assert output_spec.expand(batch_size).is_in(output) + + +class TestDeepsets: + @pytest.mark.parametrize("share_params", [True, False]) + @pytest.mark.parametrize("batch_size", [(), (2,), (3, 2)]) + def test_special_case_centralized_critic_from_agent_tensors( + self, + share_params, + batch_size, + centralised=True, + input_has_agent_dim=True, + model_name="deepsets", + n_agents=3, + in_features=4, + out_features=2, + ): + + torch.manual_seed(0) + + config = model_config_registry[model_name].get_from_yaml() + + multi_agent_input_shape = (n_agents, in_features) + other_multi_agent_input_shape = (n_agents, in_features) + + input_spec = CompositeSpec( + { + "agents": CompositeSpec( + { + "observation": UnboundedContinuousTensorSpec( + shape=multi_agent_input_shape + ), + "other": UnboundedContinuousTensorSpec( + shape=other_multi_agent_input_shape + ), + }, + shape=(n_agents,), + ) + } + ) + + if output_has_agent_dim(centralised=centralised, share_params=share_params): + output_spec = CompositeSpec( + { + "agents": CompositeSpec( + { + "out": UnboundedContinuousTensorSpec( + shape=(n_agents, out_features) + ) + }, + shape=(n_agents,), + ) + }, + ) + else: + output_spec = CompositeSpec( + {"out": UnboundedContinuousTensorSpec(shape=(out_features,))}, + ) + + model = config.get_model( + input_spec=input_spec, + output_spec=output_spec, + share_params=share_params, + centralised=centralised, + input_has_agent_dim=input_has_agent_dim, + n_agents=n_agents, + device="cpu", + agent_group="agents", + action_spec=None, + ) + input_td = input_spec.expand(batch_size).rand() + out_td = model(input_td) + assert output_spec.expand(batch_size).is_in(out_td)