From 205db60b06a91d2d0330d1279b80751595c8f706 Mon Sep 17 00:00:00 2001
From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com>
Date: Wed, 8 Nov 2023 09:59:31 +0000
Subject: [PATCH] [Feature] More callbacks (#35)

* callback

* render call back

* Render callback is called max_steps-1 times

* remove assert

* update example
---
 benchmarl/environments/common.py      | 7 ++++++
 benchmarl/experiment/callback.py      | 34 ++++++++++++++++++++++++----
 benchmarl/experiment/experiment.py    | 18 ++++++++++-----
 benchmarl/experiment/logger.py        | 5 ++--
 examples/callback/custom_callback.py  | 9 +++++---
 5 files changed, 58 insertions(+), 15 deletions(-)

diff --git a/benchmarl/environments/common.py b/benchmarl/environments/common.py
index df2b50c2..c7f49605 100644
--- a/benchmarl/environments/common.py
+++ b/benchmarl/environments/common.py
@@ -238,6 +238,13 @@ def get_reward_sum_transform(self, env: EnvBase) -> Transform:
         """
         return RewardSum(reset_keys=env.reset_keys)
 
+    @staticmethod
+    def render_callback(experiment, env: EnvBase, data: TensorDictBase):
+        try:
+            return env.render(mode="rgb_array")
+        except TypeError:
+            return env.render()
+
     def __repr__(self):
         cls_name = self.__class__.__name__
         return f"{cls_name}.{self.name}: (config={self.config})"
diff --git a/benchmarl/experiment/callback.py b/benchmarl/experiment/callback.py
index 27c1842e..c93b85d8 100644
--- a/benchmarl/experiment/callback.py
+++ b/benchmarl/experiment/callback.py
@@ -34,12 +34,27 @@ def on_batch_collected(self, batch: TensorDictBase):
         """
         pass
 
-    def on_train_end(self, training_td: TensorDictBase):
+    def on_train_step(self, batch: TensorDictBase, group: str) -> TensorDictBase:
         """
-        A callback called at the end of every training step.
+        A callback called for every training step.
+
+        Args:
+            batch (TensorDictBase): tensordict with the training batch
+            group (str): group name
+
+        Returns:
+            TensorDictBase: a new tensordict containing the loss values
+
+        """
+        pass
+
+    def on_train_end(self, training_td: TensorDictBase, group: str):
+        """
+        A callback called at the end of training.
 
         Args:
             training_td (TensorDictBase): tensordict containing the loss values
+            group (str): group name
 
         """
         pass
@@ -65,9 +80,20 @@ def on_batch_collected(self, batch: TensorDictBase):
         for callback in self.callbacks:
             callback.on_batch_collected(batch)
 
-    def on_train_end(self, training_td: TensorDictBase):
+    def on_train_step(self, batch: TensorDictBase, group: str) -> TensorDictBase:
+        train_td = None
+        for callback in self.callbacks:
+            td = callback.on_train_step(batch, group)
+            if td is not None:
+                if train_td is None:
+                    train_td = td
+                else:
+                    train_td.update(td)
+        return train_td
+
+    def on_train_end(self, training_td: TensorDictBase, group: str):
         for callback in self.callbacks:
-            callback.on_train_end(training_td)
+            callback.on_train_end(training_td, group)
 
     def on_evaluation_end(self, rollouts: List[TensorDictBase]):
         for callback in self.callbacks:
diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py
index 30f7cbc6..ba43a3df 100644
--- a/benchmarl/experiment/experiment.py
+++ b/benchmarl/experiment/experiment.py
@@ -7,6 +7,7 @@
 from __future__ import annotations
 
 import importlib
+
 import os
 import time
 from collections import OrderedDict
@@ -535,7 +536,6 @@ def _collection_loop(self):
                 step=self.n_iters_performed,
             )
             pbar.set_description(f"mean return = {self.mean_return}", refresh=False)
-            pbar.update()
 
             # Callback
             self.on_batch_collected(batch)
@@ -561,7 +561,7 @@ def _collection_loop(self):
                 )
 
                 # Callback
-                self.on_train_end(training_td)
+                self.on_train_end(training_td, group)
 
                 # Exploration update
                 if isinstance(self.group_policies[group], TensorDictSequential):
@@ -607,6 +607,7 @@ def _collection_loop(self):
                 and self.total_frames % self.config.checkpoint_interval == 0
             ):
                 self._save_experiment()
+            pbar.update()
             sampling_start = time.time()
 
         self.close()
@@ -638,6 +639,7 @@ def _optimizer_loop(self, group: str) -> TensorDictBase:
                 loss_value.backward()
 
                 grad_norm = self._grad_clip(optimizer)
+
                 training_td.set(
                     f"grad_norm_{loss_name}",
                     torch.tensor(grad_norm, device=self.config.train_device),
@@ -648,6 +650,11 @@ def _optimizer_loop(self, group: str) -> TensorDictBase:
             self.replay_buffers[group].update_tensordict_priority(subdata)
         if self.target_updaters[group] is not None:
             self.target_updaters[group].step()
+
+        callback_loss = self.on_train_step(subdata, group)
+        if callback_loss is not None:
+            training_td.update(callback_loss)
+
         return training_td
 
     def _grad_clip(self, optimizer: torch.optim.Optimizer) -> float:
@@ -672,10 +679,9 @@ def _evaluation_loop(self):
                 video_frames = []
 
                 def callback(env, td):
-                    try:
-                        video_frames.append(env.render(mode="rgb_array"))
-                    except TypeError:
-                        video_frames.append(env.render())
+                    video_frames.append(
+                        self.task.__class__.render_callback(self, env, td)
+                    )
 
             else:
                 video_frames = None
diff --git a/benchmarl/experiment/logger.py b/benchmarl/experiment/logger.py
index fb0f5d47..c4bd25be 100644
--- a/benchmarl/experiment/logger.py
+++ b/benchmarl/experiment/logger.py
@@ -88,7 +88,6 @@ def log_collection(
         total_frames: int,
         step: int,
     ) -> float:
-
         to_log = {}
         json_metrics = {}
         for group in self.group_map.keys():
@@ -218,7 +217,9 @@ def log_evaluation(
         self.log(to_log, step=step)
         if video_frames is not None:
             vid = torch.tensor(
-                np.transpose(video_frames[: rollouts[0].batch_size[0]], (0, 3, 1, 2)),
+                np.transpose(
+                    video_frames[: rollouts[0].batch_size[0] - 1], (0, 3, 1, 2)
+                ),
                 dtype=torch.uint8,
             ).unsqueeze(0)
             for logger in self.loggers:
diff --git a/examples/callback/custom_callback.py b/examples/callback/custom_callback.py
index 10957148..f724a93d 100644
--- a/examples/callback/custom_callback.py
+++ b/examples/callback/custom_callback.py
@@ -11,14 +11,18 @@
 from benchmarl.experiment import Experiment, ExperimentConfig
 from benchmarl.experiment.callback import Callback
 from benchmarl.models.mlp import MlpConfig
-from tensordict import TensorDictBase
+from tensordict import TensorDict, TensorDictBase
 
 
 class MyCallbackA(Callback):
     def on_batch_collected(self, batch: TensorDictBase):
         print(f"Callback A is doing something with the sampling batch {batch}")
 
-    def on_train_end(self, training_td: TensorDictBase):
+    def on_train_step(self, batch: TensorDictBase, group: str) -> TensorDictBase:
+        print(f"Callback A is computing a loss with the training tensordict {batch}")
+        return TensorDict({}, [])
+
+    def on_train_end(self, training_td: TensorDictBase, group: str):
         print(
             f"Callback A is doing something with the training tensordict {training_td}"
         )
@@ -37,7 +41,6 @@ def on_evaluation_end(self, rollouts: List[TensorDictBase]):
 
 
 if __name__ == "__main__":
-
     experiment_config = ExperimentConfig.get_from_yaml()
     task = VmasTask.BALANCE.get_from_yaml()
     algorithm_config = MappoConfig.get_from_yaml()
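
Note: a minimal usage sketch of the new on_train_step hook introduced by this patch. The ExtraMetricsCallback name and the logged key below are illustrative assumptions, not part of the patch. The experiment merges whatever tensordict the callback returns into training_td via training_td.update(callback_loss), so the returned entries are logged next to the algorithm losses and are visible again in on_train_end:

    import torch
    from tensordict import TensorDict, TensorDictBase

    from benchmarl.experiment.callback import Callback


    class ExtraMetricsCallback(Callback):  # hypothetical callback, for illustration only
        def on_train_step(self, batch: TensorDictBase, group: str) -> TensorDictBase:
            # Return extra scalars to be merged into the training tensordict.
            return TensorDict(
                {"callback_minibatch_size": torch.tensor(float(batch.numel()))}, []
            )

        def on_train_end(self, training_td: TensorDictBase, group: str):
            # The stacked training tensordict now also carries the callback entry.
            print(group, training_td.get("callback_minibatch_size"))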
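
Note: similarly, a sketch of how a task enum could override the new Task.render_callback hook from benchmarl/environments/common.py when an environment's render signature differs from the default try/except above. The MyTask class, its MY_SCENARIO member, and the return_rgb_array keyword are illustrative assumptions, not part of the patch:

    from benchmarl.environments.common import Task
    from tensordict import TensorDictBase
    from torchrl.envs import EnvBase


    class MyTask(Task):  # hypothetical task enum, for illustration only
        MY_SCENARIO = None

        @staticmethod
        def render_callback(experiment, env: EnvBase, data: TensorDictBase):
            # Assumption: this env exposes frames through a custom keyword
            # argument instead of render(mode="rgb_array").
            return env.render(return_rgb_array=True)

The evaluation loop calls self.task.__class__.render_callback(self, env, td), so an override defined on the task class is picked up automatically when evaluation rollouts are rendered.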