Add Documentation for Cartpole Training (#180)
Basic documentation on how to run the cartpole test, what the arguments
represent, and how to track metrics in tensorboard.

---------

Co-authored-by: Luc Baracat <[email protected]>
jaxs-ribs and Luc Baracat authored Oct 24, 2023
1 parent 8b934eb commit b2c9031
Showing 5 changed files with 139 additions and 37 deletions.
Binary file added docs/cart_pole.gif
25 changes: 25 additions & 0 deletions docs/getting_started.md
@@ -0,0 +1,25 @@
# 🔥 Getting Started

The `/experiments` folder contains example runs for different Gymnasium environments.

For example, you can train a DQN agent on cartpole with the following command:

```bash
pdm run python experiments/gym/train_dqn_cartpole.py
```

![Cartpole training run](cart_pole.gif)

This comes with a number of predefined arguments, such as the learning rate, the number of hidden layers, and the batch size. You can find all of the arguments in the `experiments/gym/train_dqn_cartpole.py` file.
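For instance, a run that overrides a few of these arguments might look like the following sketch; the flag names come from the argparse definitions in that file, while the values here are only illustrative:

```bash
# Override some of the predefined arguments (illustrative values).
pdm run python experiments/gym/train_dqn_cartpole.py \
    --num-envs 8 \
    --batch-size 256 \
    --bp-steps 100000 \
    --device cpu
```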

## 📊 TensorBoard

To visualize the training process, you can use TensorBoard. To do so, run the following command:

```bash
pdm run tensorboard --logdir ./mllogs
```

This starts a TensorBoard server on `localhost:6006`. Open that address in your browser to follow the training process, including the rewards and losses over time.
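If you only want to inspect the cartpole runs, you can point TensorBoard at the script's default `--log-dir` instead of the top-level `./mllogs` folder; a small sketch:

```bash
# Scope TensorBoard to the cartpole experiment only; the path matches the
# script's default --log-dir value.
pdm run tensorboard --logdir ./mllogs/emote/cartpole
```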

![TensorBoard dashboard](tensorboard.png)
2 changes: 2 additions & 0 deletions docs/index.rst
@@ -35,6 +35,7 @@ to build other algorithms.
coding-standard
📚 Editing documentation <documentation.md>
🌡 Metrics <metrics.md>
🚀 Getting Started <getting_started.md>

.. toctree::
:maxdepth: 6
@@ -58,6 +59,7 @@ to build other algorithms.
.. include:: adr/doc.md
.. include:: documentation.md
.. include:: metrics.md
.. include:: getting_started.md
:parser: myst_parser.sphinx_

Binary file added docs/tensorboard.png
149 changes: 112 additions & 37 deletions experiments/gym/train_dqn_cartpole.py
@@ -27,6 +27,13 @@


def _make_env():
"""Create the environment for the experiment, the environment is created in a thunk to avoid
creating multiple environments in the same process. This is important for the vectorized
environments.
Returns:
(Callable[[], gym.Env]): The thunk that creates the environment
"""

def _thunk():
env = gym.make("CartPole-v1")
env = gym.wrappers.FrameStack(env, 3)
@@ -37,6 +44,18 @@ def _thunk():


class QNet(nn.Module):
"""
Q-Network class for Q-Learning. It takes observations and returns Q-values for actions.
Attributes:
network (nn.Sequential): Neural network for computing Q-values.
Args:
num_obs (int): Dimensionality of observations.
num_actions (int): Number of possible actions.
hidden_dims (list of int): Dimensions of hidden layers.
"""

def __init__(self, num_obs, num_actions, hidden_dims):
super(QNet, self).__init__()

@@ -53,10 +72,37 @@ def __init__(self, num_obs, num_actions, hidden_dims):
self.network = nn.Sequential(*layers)

def forward(self, obs):
"""
Forward pass for the Q-Network.
Args:
obs (Tensor): Observations.
Returns:
Tensor: Q-values for each action.
"""
return self.network(obs)
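# Illustrative usage (not part of this commit): assuming CartPole-v1's 4-dimensional
# observation stacked over 3 frames and flattened to 12 inputs, and 2 discrete actions:
#   q_net = QNet(num_obs=12, num_actions=2, hidden_dims=[128, 128])
#   q_values = q_net(torch.randn(32, 12))  # expected shape: (32, 2)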


class DQNPolicy(nn.Module):
"""
DQN Policy class to handle action selection with epsilon-greedy strategy.
Attributes:
q_net (QNet): Q-Network to evaluate Q-values.
initial_epsilon (float): Initial value of epsilon in epsilon-greedy.
target_epsilon (float): Target value of epsilon.
step_count (int): Counter for steps taken.
epsilon_decay_duration (int): Steps over which epsilon is decayed.
log_epsilon (bool): Flag to log epsilon values.
Args:
q_net (QNet): Q-Network.
epsilon_range (list of float): Initial and target epsilon for epsilon-greedy.
epsilon_decay_duration (int): Number of steps over which epsilon will decay.
log_epsilon (bool): Whether to log epsilon values or not.
"""

def __init__(
self, q_net, epsilon_range=[0.9, 0.05], epsilon_decay_duration=10_000, log_epsilon=True
):
@@ -71,6 +117,15 @@ def __init__(

# Returns the index of the chosen action
def forward(self, state):
"""
Forward pass for action selection.
Args:
state (Tensor): The state observations.
Returns:
Tensor: Indices of chosen actions for each environment.
"""
with torch.no_grad():
epsilon = self.target_epsilon + (self.initial_epsilon - self.target_epsilon) * math.exp(
-1.0 * self.step_count / self.epsilon_decay_duration
@@ -120,13 +175,16 @@ def create_memory(
(tuple[TableMemoryProxy, MemoryLoader]): A proxy for the memory and a dataloader
"""
# Create the memory
table = DictObsNStepTable(
spaces=space,
use_terminal_column=False,
maxlen=memory_size,
device=device,
)
# The memory proxy is used to upload the data to the memory
memory_proxy = TableMemoryProxy(table=table, use_terminal=False)
# The data loader is used to sample the data from the memory
data_loader = MemoryLoader(
table=table,
rollout_count=batch_size // len_rollout,
@@ -152,33 +210,20 @@ def create_complementary_callbacks(
Returns:
(list[Callback]): the full list of callbacks for the training
"""
if args.use_wandb:
from emote.callbacks.wb_logger import WBLogger

config = {
"wandb_project": args.name,
"wandb_run": args.wandb_run,
"hidden_dims": args.hidden_layer_size,
"batch_size": args.batch_size,
"learning_rate": args.actor_lr,
"rollout_len": args.rollout_length,
}
logger = WBLogger(
callbacks=logged_cbs,
config=config,
log_interval=100,
)
else:
logger = TensorboardLogger(
logged_cbs,
SummaryWriter(log_dir=args.log_dir + "/" + args.name + "_{}".format(time.time())),
100,
)
# The logger callback is used for logging the training progress
logger = TensorboardLogger(
logged_cbs,
SummaryWriter(log_dir=args.log_dir + "/" + args.name + "_{}".format(time.time())),
100,
)

# Terminates the training after a certain number of backprop steps
bp_step_terminator = BackPropStepsTerminator(bp_steps=args.bp_steps)
# Callbacks to be used during training
callbacks = logged_cbs + [logger, bp_step_terminator]

if cbs_name_to_checkpoint:
# The checkpointer exports the model weights to the checkpoint directory
checkpointer = Checkpointer(
callbacks=[
cb for cb in logged_cbs if hasattr(cb, "name") and cb.name in cbs_name_to_checkpoint
@@ -192,9 +237,11 @@ def create_complementary_callbacks(


def main(args):
# Create the environment
env = DictGymWrapper(AsyncVectorEnv([_make_env() for _ in range(args.num_envs)]))
device = torch.device(args.device)

# Define the space in order to create the memory
input_shapes = {k: v.shape for k, v in env.dict_space.state.spaces.items()}
output_shapes = {"actions": env.dict_space.actions.shape}
action_shape = output_shapes["actions"]
@@ -231,14 +278,17 @@ def main(args):

num_actions = env.action_space.nvec[0]

# Create our two networks and the policy
online_q_net = QNet(num_obs, num_actions, args.hidden_dims)
target_q_net = QNet(num_obs, num_actions, args.hidden_dims)
policy = DQNPolicy(online_q_net)

# Move them to the device
online_q_net = online_q_net.to(device)
target_q_net = target_q_net.to(device)
policy = policy.to(device)

# The agent proxy is responsible for inference
agent_proxy = GenericAgentProxy(
policy,
device=device,
Expand All @@ -248,6 +298,7 @@ def main(args):
spaces=spaces,
)

# Create an optimizer for the online network
optimizers = [
QLoss(
name="q1",
@@ -258,11 +309,13 @@
]

train_callbacks = optimizers + [
# The QTarget callback is responsible for updating the target network
QTarget(
q_net=online_q_net,
target_q_net=target_q_net,
roll_length=args.rollout_length,
),
# The collector is responsible for the interaction with the environment
ThreadedGymCollector(
env,
agent_proxy,
@@ -277,29 +330,51 @@
train_callbacks,
)

# The trainer acts as the main callback, responsible for calling all other callbacks
trainer = Trainer(all_callbacks, dataloader)
trainer.train()


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--name", type=str, default="cartpole")
parser.add_argument("--log-dir", type=str, default="./mllogs/emote/cartpole")
parser.add_argument("--num-envs", type=int, default=4)
parser.add_argument("--rollout-length", type=int, default=1)
parser.add_argument("--batch-size", type=int, default=128)
parser.add_argument("--hidden-dims", type=list, default=[128, 128])
parser.add_argument("--lr", type=float, default=1e-3, help="The learning rate")
parser.add_argument("--device", type=str, default="cpu")
parser.add_argument("--bp-steps", type=int, default=50_000)
parser.add_argument("--memory-size", type=int, default=50_000)
parser.add_argument("--export-memory", action="store_true", default=False)
parser.add_argument("--use-wandb", action="store_true")
parser.add_argument("--name", type=str, default="cartpole", help="The name of the experiment")
parser.add_argument(
"--wandb-run",
"--log-dir",
type=str,
default=None,
help="Short display name of run for the W&B UI. Randomly generated by default.",
default="./mllogs/emote/cartpole",
help="Directory where logs will be stored.",
)
parser.add_argument(
"--num-envs", type=int, default=4, help="Number of environments to run in parallel"
)
parser.add_argument(
"--rollout-length",
type=int,
default=1,
help="The length of each rollout. Refers to the number of steps or time-steps taken during a simulated trajectory or rollout when estimating the expected return of a policy.",
)
parser.add_argument("--batch-size", type=int, default=128, help="Size of each training batch")
parser.add_argument(
"--hidden-dims", type=list, default=[128, 128], help="The hidden dimensions of the network"
)
parser.add_argument("--lr", type=float, default=1e-3, help="Learning Rate")
parser.add_argument(
"--device", type=str, default="cpu", help="Device to run the model on, e.g. cpu or cuda:0"
)
parser.add_argument(
"--bp-steps",
type=int,
default=50_000,
help="Number of backpropagation steps until the training run is finished",
)
parser.add_argument(
"--memory-size",
type=int,
default=50_000,
help="The size of the replay buffer. More complex environments require larger replay buffers, as they need more data to learn. Given that cartpole is a simple environment, a replay buffer of size 50_000 is sufficient.",
)
parser.add_argument(
"--export-memory", action="store_true", default=False, help="Whether to export the memory"
)
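# Note (not part of this commit): argparse applies ``type=list`` to a command-line string
# character by character, so --hidden-dims is effectively only configurable by editing its
# default. A common alternative would be:
#   parser.add_argument("--hidden-dims", type=int, nargs="+", default=[128, 128])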
args = parser.parse_args()
main(args)