Commit d2bde6b: Docs

matteobettini committed Nov 25, 2023
1 parent 3a6ae9d commit d2bde6b
Showing 7 changed files with 77 additions and 10 deletions.
16 changes: 8 additions & 8 deletions benchmarl/algorithms/common.py
@@ -32,7 +32,7 @@ class Algorithm(ABC):
This should be overridden by implemented algorithms
and all abstract methods should be implemented.
Args:
experiment (Experiment): the experiment class
"""

@@ -104,14 +104,13 @@ def _check_specs(self):
def get_loss_and_updater(self, group: str) -> Tuple[LossModule, TargetNetUpdater]:
"""
Get the LossModule and TargetNetUpdater for a specific group.
This function calls the abstract self._get_loss() which needs to be implemented.
This function calls the abstract :class:`~benchmarl.algorithms.Algorithm._get_loss()` which needs to be implemented.
The function will cache the output at the first call and return the cached values in future calls.
Args:
group (str): agent group of the loss and updater
Returns: LossModule and TargetNetUpdater for the group
"""
if group not in self._losses_and_updaters.keys():
action_space = self.action_spec[group, "action"]
@@ -144,7 +143,7 @@ def get_replay_buffer(
) -> ReplayBuffer:
"""
Get the ReplayBuffer for a specific group.
This function will check self.on_policy and create the buffer accordingly
This function will check ``self.on_policy`` and create the buffer accordingly
Args:
group (str): agent group of the replay buffer
@@ -165,7 +164,7 @@ def get_replay_buffer(
def get_policy_for_loss(self, group: str) -> TensorDictModule:
"""
Get the non-explorative policy for a specific group loss.
This function calls the abstract self._get_policy_for_loss() which needs to be implemented.
This function calls the abstract :class:`~benchmarl.algorithms.Algorithm._get_policy_for_loss()` which needs to be implemented.
The function will cache the output at the first call and return the cached values in future calls.
Args:
@@ -192,7 +191,7 @@ def get_policy_for_loss(self, group: str) -> TensorDictModule:
def get_policy_for_collection(self) -> TensorDictSequential:
"""
Get the explorative policy for all groups together.
This function calls the abstract self._get_policy_for_collection() which needs to be implemented.
This function calls the abstract :class:`~benchmarl.algorithms.Algorithm._get_policy_for_collection()` which needs to be implemented.
The function will cache the output at the first call and return the cached values in future calls.
Returns: TensorDictSequential representing all explorative policies
@@ -217,7 +216,7 @@ def get_policy_for_collection(self) -> TensorDictSequential:
def get_parameters(self, group: str) -> Dict[str, Iterable]:
"""
Get the dictionary mapping loss names to the relative parameters to optimize for a given group.
This function calls the abstract self._get_parameters() which needs to be implemented.
This function calls the abstract :class:`~benchmarl.algorithms.Algorithm._get_parameters()` which needs to be implemented.
Returns: a dictionary mapping loss names to a list of parameters
"""
@@ -332,6 +331,7 @@ class AlgorithmConfig:
def get_algorithm(self, experiment) -> Algorithm:
"""
Main function to turn the config into the associated algorithm
Args:
experiment (Experiment): the experiment class
@@ -361,7 +361,7 @@ def get_from_yaml(cls, path: Optional[str] = None):
Args:
path (str, optional): The full path of the yaml file to load from.
If None, it will default to
benchmarl/conf/algorithm/self.associated_class().__name__
``benchmarl/conf/algorithm/self.associated_class().__name__``
Returns: the loaded AlgorithmConfig
"""
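For context, here is a minimal sketch of how the `Algorithm`/`AlgorithmConfig` API documented above fits together. The `Experiment` construction follows the library's usual README-style usage and is not part of this diff; the VMAS task, the MLP model config, and the `"agents"` group name are illustrative assumptions.

```python
from benchmarl.algorithms import IppoConfig
from benchmarl.environments import VmasTask
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models.mlp import MlpConfig

# Load the algorithm config from its packaged yaml
# (default path derived from the associated class name, as documented above)
algorithm_config = IppoConfig.get_from_yaml()

# Build an experiment the usual way (assumed setup, not part of this commit)
experiment = Experiment(
    task=VmasTask.BALANCE.get_from_yaml(),
    algorithm_config=algorithm_config,
    model_config=MlpConfig.get_from_yaml(),
    critic_model_config=MlpConfig.get_from_yaml(),
    seed=0,
    config=ExperimentConfig.get_from_yaml(),
)

# Turn the config into the associated algorithm
# (Experiment already does this internally; repeated here just to show the call)
algorithm = algorithm_config.get_algorithm(experiment)

# Per-group objects; results are cached after the first call, as documented
loss, target_updater = algorithm.get_loss_and_updater(group="agents")
policy_for_loss = algorithm.get_policy_for_loss(group="agents")
explorative_policy = algorithm.get_policy_for_collection()
parameters = algorithm.get_parameters(group="agents")
```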
12 changes: 12 additions & 0 deletions benchmarl/algorithms/iddpg.py
@@ -19,6 +19,16 @@


class Iddpg(Algorithm):
"""Same as :class:`~benchmarkl.algorithms.Maddpg` (from `https://arxiv.org/abs/1706.02275 <https://arxiv.org/abs/1706.02275>`__) but with decentralized critics.
Args:
share_param_critic (bool): Whether to share the parameters of the critics withing agent groups
loss_function (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1".
delay_value (bool): whether to separate the target value networks from the value networks used for
data collection.
"""

def __init__(
self, share_param_critic: bool, loss_function: str, delay_value: bool, **kwargs
):
@@ -227,6 +237,8 @@ def get_value_module(self, group: str) -> TensorDictModule:

@dataclass
class IddpgConfig(AlgorithmConfig):
"""Configuration dataclass for :class:`~benchmarkl.algorithms.Iddpg`."""

share_param_critic: bool = MISSING
loss_function: str = MISSING
delay_value: bool = MISSING
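As a quick illustration of the options documented above (the values are arbitrary examples, not the packaged defaults):

```python
from benchmarl.algorithms import IddpgConfig

# Start from the packaged defaults, then adjust the documented knobs.
config = IddpgConfig.get_from_yaml()
config.share_param_critic = True   # share critic parameters within each agent group
config.loss_function = "l2"        # one of "l1", "l2", "smooth_l1"
config.delay_value = True          # keep target value networks separate from the collection ones
```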
17 changes: 17 additions & 0 deletions benchmarl/algorithms/ippo.py
@@ -22,6 +22,21 @@


class Ippo(Algorithm):
"""Independent PPO (from `https://arxiv.org/abs/2011.09533 <https://arxiv.org/abs/2011.09533>`__).
Args:
share_param_critic (bool): Whether to share the parameters of the critics within agent groups
clip_epsilon (scalar): weight clipping threshold in the clipped PPO loss equation.
entropy_coef (scalar): entropy multiplier when computing the total loss.
critic_coef (scalar): critic loss multiplier when computing the total loss.
loss_critic_type (str): loss function for the value discrepancy.
Can be one of "l1", "l2" or "smooth_l1".
lmbda (float): The GAE lambda.
scale_mapping (str): positive mapping function to be used with the std.
choices: "softplus", "exp", "relu", "biased_softplus_1".
"""

def __init__(
self,
share_param_critic: bool,
@@ -270,6 +285,8 @@ def get_critic(self, group: str) -> TensorDictModule:

@dataclass
class IppoConfig(AlgorithmConfig):
"""Configuration dataclass for :class:`~benchmarkl.algorithms.Ippo`."""

share_param_critic: bool = MISSING
clip_epsilon: float = MISSING
entropy_coef: float = MISSING
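A small, self-contained sketch of what two of the documented options control, in plain PyTorch and independent of BenchMARL internals:

```python
import torch
import torch.nn.functional as F

# clip_epsilon bounds the importance weight in the clipped PPO objective
ratio = torch.tensor([0.5, 1.0, 1.7])
clip_epsilon = 0.2
clipped_ratio = ratio.clamp(1 - clip_epsilon, 1 + clip_epsilon)  # tensor([0.8000, 1.0000, 1.2000])

# scale_mapping maps the network's raw std output to a strictly positive value;
# "softplus" is one of the documented choices
raw_std = torch.tensor([-2.0, 0.0, 3.0])
positive_std = F.softplus(raw_std)  # approx. tensor([0.1269, 0.6931, 3.0486])
```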
11 changes: 11 additions & 0 deletions benchmarl/algorithms/iql.py
@@ -18,6 +18,15 @@


class Iql(Algorithm):
"""Independent Q Learning (from `https://www.semanticscholar.org/paper/Multi-Agent-Reinforcement-Learning%3A-Independent-Tan/59de874c1e547399b695337bcff23070664fa66e <https://www.semanticscholar.org/paper/Multi-Agent-Reinforcement-Learning%3A-Independent-Tan/59de874c1e547399b695337bcff23070664fa66e>`__).
Args:
loss_function (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1".
delay_value (bool): whether to separate the target value networks from the value networks used for
data collection.
"""

def __init__(self, delay_value: bool, loss_function: str, **kwargs):
super().__init__(**kwargs)

@@ -175,6 +184,8 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:

@dataclass
class IqlConfig(AlgorithmConfig):
"""Configuration dataclass for :class:`~benchmarkl.algorithms.Iql`."""

delay_value: bool = MISSING
loss_function: str = MISSING

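The documented loss_function names correspond to the standard value-regression losses. A hedged sketch in plain PyTorch; whether the underlying torchrl loss applies exactly these reductions internally is an assumption here:

```python
import torch
import torch.nn.functional as F

predicted_q = torch.tensor([1.0, 2.0, 3.0])
target_q = torch.tensor([1.5, 2.0, 0.0])

l1 = F.l1_loss(predicted_q, target_q)                 # mean absolute error
l2 = F.mse_loss(predicted_q, target_q)                # mean squared error
smooth_l1 = F.smooth_l1_loss(predicted_q, target_q)   # quadratic near zero, linear for large errors
```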
18 changes: 18 additions & 0 deletions benchmarl/algorithms/isac.py
@@ -26,6 +26,24 @@


class Isac(Algorithm):
"""Independent Soft Actor Critic.
Args:
share_param_critic (bool): Whether to share the parameters of the critics within agent groups
num_qvalue_nets (integer): number of Q-Value networks used.
loss_function (str): loss function to be used with
the value function loss.
delay_qvalue (bool): whether to separate the target Q value networks from the Q value networks used for data collection.
target_entropy (float or str): target entropy for the stochastic policy. Can be "auto", in which case it is computed automatically.
discrete_target_entropy_weight (float): weight used when computing the target entropy for discrete action spaces.
alpha_init (float): initial value of the entropy multiplier alpha.
min_alpha (float): minimum value that alpha can take.
max_alpha (float): maximum value that alpha can take.
fixed_alpha (bool): if True, alpha is kept fixed at its initial value and not optimized.
scale_mapping (str): positive mapping function to be used with the std.
Can be one of "softplus", "exp", "relu", "biased_softplus_1".
"""

def __init__(
self,
share_param_critic: bool,
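For illustration, the SAC-specific knobs documented above could be adjusted like this. Field names follow the __init__ signature shown in the diff; the values are arbitrary examples, not recommended defaults:

```python
from benchmarl.algorithms import IsacConfig

config = IsacConfig.get_from_yaml()   # packaged defaults
config.share_param_critic = True
config.num_qvalue_nets = 2            # twin Q networks, the common SAC setup
config.loss_function = "l2"
config.alpha_init = 1.0               # starting entropy coefficient
config.fixed_alpha = False            # learn alpha instead of keeping it fixed
```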
9 changes: 9 additions & 0 deletions docs/source/_templates/autosummary/class_private.rst
@@ -0,0 +1,9 @@
{{ fullname | escape | underline }}

.. currentmodule:: {{ module }}

.. autoclass:: {{ objname }}
:show-inheritance:
:members:
:undoc-members:
:private-members:
4 changes: 2 additions & 2 deletions docs/source/modules/algorithms.rst
@@ -15,7 +15,7 @@ Common
.. autosummary::
:nosignatures:
:toctree: ../generated
:template: autosummary/class.rst
:template: autosummary/class_private.rst

Algorithm
AlgorithmConfig
@@ -26,7 +26,7 @@ Algorithms
.. autosummary::
:nosignatures:
:toctree: ../generated
:template: autosummary/class.rst
:template: autosummary/class_private.rst

{% for name in benchmarl.algorithms.classes %}
{{ name }}
