diff --git a/benchmarl/models/deepsets.py b/benchmarl/models/deepsets.py index ee4b1b81..3db840be 100644 --- a/benchmarl/models/deepsets.py +++ b/benchmarl/models/deepsets.py @@ -18,12 +18,12 @@ class Deepsets(Model): - """Deepsets Model from https://arxiv.org/abs/1703.06114 + r"""Deepsets Model from `this paper <https://arxiv.org/abs/1703.06114>`__ . The BenchMARL Deepsets accepts multiple inputs of 2 types: - - sets :math:`s`: Tensors of shape ``(*batch,S,F)`` - - arrays :math:`x`: Tensors of shape ``(*batch,F)`` + - sets :math:`s` : Tensors of shape ``(*batch,S,F)`` + - arrays :math:`x` : Tensors of shape ``(*batch,F)`` The Deepsets model will check that all set inputs have the same shape (excluding the last dimension) and cat them along that dimension before processing them. @@ -31,25 +31,37 @@ class Deepsets(Model): It will check that all array inputs have the same shape (excluding the last dimension) and cat them along that dimension. - It will then compute the output according to the following function: + It will then compute the output according to the following function. .. math:: - \rho \left (x, \bigoplus_{s\in S}\phi(s) \right ) + \rho \left (x, \bigoplus_{s\in S}\phi(s) \right ), Where :math:`\rho,\phi` are MLPs configurable in the model setup. + The model is useful in various contexts, for example: + + - When used as a policy (``self.centralized==False``, ``self.input_has_agent_dim==True``), it can process + observations with shape ``(*batch,n_agents,S,F)``, reducing them to ``(*batch,n_agents,F)``. + - When used as a centralized critic with a global state as input + (``self.centralized==True``, ``self.input_has_agent_dim==False``), it can process the global state with shape + ``(*batch,S,F)`` , reducing it to ``(*batch,F)``. + - When used as a centralized critic with local agent observations as input + (``self.centralized==True``, ``self.input_has_agent_dim==True``), it can process normal agent observations with shape + ``(*batch,n_agents,F)``, reducing them to ``(*batch,F)``. 
**Note**: If the agents also have set observations + ``(*batch,n_agents,S,F)`` it will apply two deep sets networks. The first will remove the set dimension + in the agents' inputs (``(*batch,n_agents,F)``), and the second will remove the agent dimension (``(*batch,F)``). + Both networks will share the same configuration. + Args: aggr (str): The aggregation strategy to use in the Deepsets model. - local_nn_num_cells (Sequence[int]): number of cells of every layer in between the input and output in the - :math:`\phi` MLP. - local_nn_activation_class (Type[nn.Module]): activation class to be used in the - :math:`\phi` MLP. + local_nn_num_cells (Sequence[int]): number of cells of every layer in between the input and output in the :math:`\phi` MLP. + local_nn_activation_class (Type[nn.Module]): activation class to be used in the :math:`\phi` MLP. out_features_local_nn (int): output features of the :math:`\phi` MLP. - global_nn_num_cells (Sequence[int]): number of cells of every layer in between the input and output in the - :math:`\rho` MLP. - global_nn_activation_class (Type[nn.Module]): activation class to be used in the - :math:`\prho` MLP. + global_nn_num_cells (Sequence[int]): number of cells of every layer in between the input and output in the :math:`\rho` MLP. + global_nn_activation_class (Type[nn.Module]): activation class to be used in the :math:`\rho` MLP. + + """ def __init__( diff --git a/docs/source/conf.py b/docs/source/conf.py index 4232fd7c..12b93c25 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -23,6 +23,7 @@ "sphinx.ext.napoleon", "sphinx.ext.intersphinx", "sphinx.ext.viewcode", + "sphinx.ext.mathjax", "patch", ]