Add policy documentation links to policy_kwargs parameter #2050

Merged (2 commits, Dec 2, 2024)
docs/misc/changelog.rst (3 changes: 2 additions & 1 deletion)

@@ -38,6 +38,7 @@ Documentation:
^^^^^^^^^^^^^^
- Added Decisions and Dragons to resources. (@jmacglashan)
- Updated PyBullet example, now compatible with Gymnasium
+- Added link to policies for ``policy_kwargs`` parameter (@kplers)

Release 2.4.0 (2024-11-18)
--------------------------
@@ -1738,4 +1739,4 @@ And all the contributors:
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger
@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 @will-maclean
-@brn-dev @jmacglashan
+@brn-dev @jmacglashan @kplers
docs/modules/a2c.rst (6 changes: 4 additions & 2 deletions)

@@ -78,7 +78,7 @@ Train a A2C agent on ``CartPole-v1`` using 4 environments.

A2C is meant to be run primarily on the CPU, especially when you are not using a CNN. To improve CPU utilization, try turning off the GPU and using ``SubprocVecEnv`` instead of the default ``DummyVecEnv``:

-.. code-block::
+.. code-block:: python

from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env
@@ -88,7 +88,7 @@ Train a A2C agent on ``CartPole-v1`` using 4 environments.
env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv)
model = A2C("MlpPolicy", env, device="cpu")
model.learn(total_timesteps=25_000)

For more information, see :ref:`Vectorized Environments <vec_env>`, `Issue #1245 <https://github.com/DLR-RM/stable-baselines3/issues/1245>`_ or the `Multiprocessing notebook <https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb>`_.


@@ -165,6 +165,8 @@ Parameters
:inherited-members:


+.. _a2c_policies:
+
A2C Policies
-------------

docs/modules/ppo.rst (6 changes: 4 additions & 2 deletions)

@@ -92,7 +92,7 @@ Train a PPO agent on ``CartPole-v1`` using 4 environments.

PPO is meant to be run primarily on the CPU, especially when you are not using a CNN. To improve CPU utilization, try turning off the GPU and using ``SubprocVecEnv`` instead of the default ``DummyVecEnv``:

-.. code-block::
+.. code-block:: python

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
@@ -102,7 +102,7 @@ Train a PPO agent on ``CartPole-v1`` using 4 environments.
env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv)
model = PPO("MlpPolicy", env, device="cpu")
model.learn(total_timesteps=25_000)

For more information, see :ref:`Vectorized Environments <vec_env>`, `Issue #1245 <https://github.com/DLR-RM/stable-baselines3/issues/1245#issuecomment-1435766949>`_ or the `Multiprocessing notebook <https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb>`_.

Results
@@ -178,6 +178,8 @@ Parameters
:inherited-members:


+.. _ppo_policies:
+
PPO Policies
-------------

stable_baselines3/a2c/a2c.py (2 changes: 1 addition & 1 deletion)

@@ -48,7 +48,7 @@ class A2C(OnPolicyAlgorithm):
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
-:param policy_kwargs: additional arguments to be passed to the policy on creation
+:param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`a2c_policies`
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
debug messages
:param seed: Seed for the pseudo random generators
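
For context on the new cross-reference: a minimal sketch of how ``policy_kwargs`` is typically passed to ``A2C`` (the environment id, layer sizes, and activation function are illustrative, not part of this diff; the same pattern applies to ``PPO``):

.. code-block:: python

    import torch

    from stable_baselines3 import A2C

    # Keys in policy_kwargs are forwarded to the policy constructor on creation;
    # net_arch and activation_fn are among the documented options.
    policy_kwargs = dict(
        net_arch=dict(pi=[64, 64], vf=[64, 64]),  # actor / value-function layers
        activation_fn=torch.nn.ReLU,
    )

    model = A2C("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs)
    model.learn(total_timesteps=10_000)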
stable_baselines3/ddpg/ddpg.py (2 changes: 1 addition & 1 deletion)

@@ -44,7 +44,7 @@ class DDPG(TD3):
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
at a cost of more complexity.
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
-:param policy_kwargs: additional arguments to be passed to the policy on creation
+:param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`ddpg_policies`
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
debug messages
:param seed: Seed for the pseudo random generators
stable_baselines3/dqn/dqn.py (2 changes: 1 addition & 1 deletion)

@@ -53,7 +53,7 @@ class DQN(OffPolicyAlgorithm):
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
-:param policy_kwargs: additional arguments to be passed to the policy on creation
+:param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`dqn_policies`
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
debug messages
:param seed: Seed for the pseudo random generators
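
Similarly, a minimal sketch for ``DQN``, where ``net_arch`` sets the hidden layers of the Q-network (the values are illustrative):

.. code-block:: python

    from stable_baselines3 import DQN

    # policy_kwargs is forwarded to the DQN policy on creation;
    # net_arch gives the hidden-layer sizes of the Q-network.
    model = DQN(
        "MlpPolicy",
        "CartPole-v1",
        policy_kwargs=dict(net_arch=[256, 256]),
    )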
stable_baselines3/ppo/ppo.py (2 changes: 1 addition & 1 deletion)

@@ -62,7 +62,7 @@ class PPO(OnPolicyAlgorithm):
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
-:param policy_kwargs: additional arguments to be passed to the policy on creation
+:param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`ppo_policies`
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
debug messages
:param seed: Seed for the pseudo random generators
stable_baselines3/sac/sac.py (2 changes: 1 addition & 1 deletion)

@@ -68,7 +68,7 @@ class SAC(OffPolicyAlgorithm):
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
-:param policy_kwargs: additional arguments to be passed to the policy on creation
+:param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`sac_policies`
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
debug messages
:param seed: Seed for the pseudo random generators
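
And a sketch for the off-policy actor-critic case: ``SAC`` (likewise ``TD3`` and ``DDPG``) uses ``qf`` rather than ``vf`` for the critic part of ``net_arch``. The environment id and layer sizes are illustrative:

.. code-block:: python

    from stable_baselines3 import SAC

    # Separate hidden layers for the actor (pi) and the Q-function critics (qf).
    model = SAC(
        "MlpPolicy",
        "Pendulum-v1",
        policy_kwargs=dict(net_arch=dict(pi=[256, 256], qf=[256, 256])),
    )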
stable_baselines3/td3/td3.py (2 changes: 1 addition & 1 deletion)

@@ -56,7 +56,7 @@ class TD3(OffPolicyAlgorithm):
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
-:param policy_kwargs: additional arguments to be passed to the policy on creation
+:param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`td3_policies`
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
debug messages
:param seed: Seed for the pseudo random generators