Add policy documentation links to policy_kwargs parameter #266

Merged 2 commits on Dec 2, 2024
Changes from all commits
1 change: 1 addition & 0 deletions docs/modules/ppo_mask.rst
@@ -245,6 +245,7 @@ Parameters
:members:
:inherited-members:

.. _ppo_mask_policies:

MaskablePPO Policies
--------------------
1 change: 1 addition & 0 deletions docs/modules/ppo_recurrent.rst
@@ -125,6 +125,7 @@ Parameters
:members:
:inherited-members:

.. _ppo_recurrent_policies:

RecurrentPPO Policies
---------------------
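Both .rst hunks above follow the same pattern: a Sphinx anchor (`.. _name:`) is placed directly above the policies heading so that the `:ref:` roles added to the docstrings below can link to it. A minimal sketch of how the pair fits together, using the label from the ppo_mask diff (the surrounding prose is illustrative, not part of this PR):

```rst
.. _ppo_mask_policies:

MaskablePPO Policies
--------------------

Elsewhere in the docs, :ref:`ppo_mask_policies` renders as a link to this
section; :ref:`custom link text <ppo_mask_policies>` overrides the title.
```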
6 changes: 3 additions & 3 deletions sb3_contrib/__init__.py
@@ -15,10 +15,10 @@

__all__ = [
"ARS",
"CrossQ",
"MaskablePPO",
"RecurrentPPO",
"QRDQN",
"TQC",
"TRPO",
"CrossQ",
"MaskablePPO",
"RecurrentPPO",
]
2 changes: 1 addition & 1 deletion sb3_contrib/ars/ars.py
@@ -40,7 +40,7 @@ class ARS(BaseAlgorithm):
:param zero_policy: Boolean determining if the passed policy should have its weights zeroed before training.
:param alive_bonus_offset: Constant added to the reward at each step, used to cancel out alive bonuses.
:param n_eval_episodes: Number of episodes to evaluate each candidate.
:param policy_kwargs: Keyword arguments to pass to the policy on creation
:param policy_kwargs: Keyword arguments to pass to the policy on creation. See :ref:`ars_policies`
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: String with the directory to put tensorboard logs:
2 changes: 1 addition & 1 deletion sb3_contrib/common/torch_layers.py
@@ -1,6 +1,6 @@
import torch

__all__ = ["BatchRenorm1d", "BatchRenorm"]
__all__ = ["BatchRenorm", "BatchRenorm1d"]


class BatchRenorm(torch.nn.Module):
2 changes: 1 addition & 1 deletion sb3_contrib/crossq/crossq.py
@@ -56,7 +56,7 @@ class CrossQ(OffPolicyAlgorithm):
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`crossq_policies`
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
debug messages
:param seed: Seed for the pseudo random generators
2 changes: 1 addition & 1 deletion sb3_contrib/ppo_mask/__init__.py
@@ -1,4 +1,4 @@
from sb3_contrib.ppo_mask.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from sb3_contrib.ppo_mask.ppo_mask import MaskablePPO

__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "MaskablePPO"]
__all__ = ["CnnPolicy", "MaskablePPO", "MlpPolicy", "MultiInputPolicy"]
2 changes: 1 addition & 1 deletion sb3_contrib/ppo_mask/ppo_mask.py
@@ -57,7 +57,7 @@ class MaskablePPO(OnPolicyAlgorithm):
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`ppo_mask_policies`
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
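The docstring tweak points readers at the policy section for what `policy_kwargs` accepts. A sketch of the usage being documented; the hyperparameter values and the commented-out constructor call are hypothetical, not from this PR:

```python
# policy_kwargs is forwarded to the policy's __init__, so valid keys are the
# constructor arguments listed in the linked policy docs. Values below are
# hypothetical.
policy_kwargs = dict(
    net_arch=dict(pi=[64, 64], vf=[64, 64]),  # separate actor/critic sizes
)

# model = MaskablePPO("MlpPolicy", env, policy_kwargs=policy_kwargs)
# model.learn(total_timesteps=10_000)
```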
2 changes: 1 addition & 1 deletion sb3_contrib/ppo_recurrent/ppo_recurrent.py
@@ -57,7 +57,7 @@ class RecurrentPPO(OnPolicyAlgorithm):
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`ppo_recurrent_policies`
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
2 changes: 1 addition & 1 deletion sb3_contrib/qrdqn/__init__.py
@@ -1,4 +1,4 @@
from sb3_contrib.qrdqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from sb3_contrib.qrdqn.qrdqn import QRDQN

__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "QRDQN"]
__all__ = ["QRDQN", "CnnPolicy", "MlpPolicy", "MultiInputPolicy"]
2 changes: 1 addition & 1 deletion sb3_contrib/qrdqn/qrdqn.py
@@ -53,7 +53,7 @@ class QRDQN(OffPolicyAlgorithm):
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`qrdqn_policies`
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
2 changes: 1 addition & 1 deletion sb3_contrib/tqc/__init__.py
@@ -1,4 +1,4 @@
from sb3_contrib.tqc.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from sb3_contrib.tqc.tqc import TQC

__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "TQC"]
__all__ = ["TQC", "CnnPolicy", "MlpPolicy", "MultiInputPolicy"]
2 changes: 1 addition & 1 deletion sb3_contrib/tqc/tqc.py
@@ -60,7 +60,7 @@ class TQC(OffPolicyAlgorithm):
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`tqc_policies`
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
2 changes: 1 addition & 1 deletion sb3_contrib/trpo/__init__.py
@@ -1,4 +1,4 @@
from sb3_contrib.trpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from sb3_contrib.trpo.trpo import TRPO

__all__ = ["CnnPolicy", "MlpPolicy", "MultiInputPolicy", "TRPO"]
__all__ = ["TRPO", "CnnPolicy", "MlpPolicy", "MultiInputPolicy"]
2 changes: 1 addition & 1 deletion sb3_contrib/trpo/trpo.py
@@ -64,7 +64,7 @@ class TRPO(OnPolicyAlgorithm):
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`trpo_policies`
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.