Skip to content

Commit

Permalink
docs
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Jun 13, 2024
1 parent defcd49 commit 6670369
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 13 deletions.
38 changes: 25 additions & 13 deletions benchmarl/models/deepsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,38 +18,50 @@


class Deepsets(Model):
"""Deepsets Model from https://arxiv.org/abs/1703.06114
r"""Deepsets Model from `this paper <https://arxiv.org/abs/1703.06114>`__ .
The BenchMARL Deepsets accepts multiple inputs of 2 types:
- sets :math:`s`: Tensors of shape ``(*batch,S,F)``
- arrays :math:`x`: Tensors of shape ``(*batch,F)``
- sets :math:`s` : Tensors of shape ``(*batch,S,F)``
- arrays :math:`x` : Tensors of shape ``(*batch,F)``
The Deepsets model will check that all set inputs have the same shape (excluding the last dimension)
and cat them along that dimension before processing them.
It will check that all array inputs have the same shape (excluding the last dimension)
and cat them along that dimension.
It will then compute the output according to the following function:
It will then compute the output according to the following function.
.. math::
\rho \left (x, \bigoplus_{s\in S}\phi(s) \right )
\rho \left (x, \bigoplus_{s\in S}\phi(s) \right ),
Where :math:`\rho,\phi` are MLPs configurable in the model setup.
The model is useful in various contexts, for example:
- When used as a policy (``self.centralized==False``, ``self.input_has_agent_dim==True``), it can process
observations with shape ``(*batch,n_agents,S,F)``, reducing them to ``(*batch,n_agents,F)``
- When used a a centralized crtic with a global state as input
(``self.centralized==True``, ``self.input_has_agent_dim==False``), it can process the global state with shape
``(*batch,S,F)`` , reducing it to ``(*batch,F)``.
- When used a a centralized crtic with local agent observations as input
(``self.centralized==True``, ``self.input_has_agent_dim==True``), it can process normal agent observations with shape
``(*batch,n_agents,F)``, reducing them to ``(*batch,F)``. **Note**: If the agents also have set observations
``(*batch,n_agents,S,F)`` it will apply two deep sets networks. The first will remove the set dimension
in the agents' inputs (``(*batch,n_agents,F)``), and the second will remove the agent dimension (``(*batch,F)``).
Both networks will share the same configuration.
Args:
aggr (str): The aggregation strategy to use in the Deepsets model.
local_nn_num_cells (Sequence[int]): number of cells of every layer in between the input and output in the
:math:`\phi` MLP.
local_nn_activation_class (Type[nn.Module]): activation class to be used in the
:math:`\phi` MLP.
local_nn_num_cells (Sequence[int]): number of cells of every layer in between the input and output in the :math:`\phi` MLP.
local_nn_activation_class (Type[nn.Module]): activation class to be used in the :math:`\phi` MLP.
out_features_local_nn (int): output features of the :math:`\phi` MLP.
global_nn_num_cells (Sequence[int]): number of cells of every layer in between the input and output in the
:math:`\rho` MLP.
global_nn_activation_class (Type[nn.Module]): activation class to be used in the
:math:`\prho` MLP.
global_nn_num_cells (Sequence[int]): number of cells of every layer in between the input and output in the :math:`\rho` MLP.
global_nn_activation_class (Type[nn.Module]): activation class to be used in the :math:`\rho` MLP.
"""

def __init__(
Expand Down
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"sphinx.ext.napoleon",
"sphinx.ext.intersphinx",
"sphinx.ext.viewcode",
"sphinx.ext.mathjax",
"patch",
]

Expand Down

0 comments on commit 6670369

Please sign in to comment.