[Feature] More callbacks (#35)
* callback

* render callback

* Render callback is called max_steps-1 times

* remove assert

* update example
matteobettini authored Nov 8, 2023
1 parent ea835d0 commit 205db60
Showing 5 changed files with 58 additions and 15 deletions.
7 changes: 7 additions & 0 deletions benchmarl/environments/common.py
@@ -238,6 +238,13 @@ def get_reward_sum_transform(self, env: EnvBase) -> Transform:
"""
return RewardSum(reset_keys=env.reset_keys)

@staticmethod
def render_callback(experiment, env: EnvBase, data: TensorDictBase):
try:
return env.render(mode="rgb_array")
except TypeError:
return env.render()

def __repr__(self):
cls_name = self.__class__.__name__
return f"{cls_name}.{self.name}: (config={self.config})"
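The new render_callback hook lets a task customize how evaluation frames are produced. Below is a minimal sketch of overriding it in a hypothetical task; the class name, the enum member, and the frame cropping are illustrative assumptions, not part of this commit.

    import numpy as np
    from tensordict import TensorDictBase
    from torchrl.envs import EnvBase

    from benchmarl.environments.common import Task

    class MyTask(Task):
        # Hypothetical task enum; concrete tasks define members like this one.
        MY_SCENARIO = None

        @staticmethod
        def render_callback(experiment, env: EnvBase, data: TensorDictBase):
            # Render as the default hook does, then crop a border from the
            # frame before it is collected into the evaluation video.
            try:
                frame = env.render(mode="rgb_array")
            except TypeError:
                frame = env.render()
            return np.asarray(frame)[10:-10, 10:-10]
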
34 changes: 30 additions & 4 deletions benchmarl/experiment/callback.py
@@ -34,12 +34,27 @@ def on_batch_collected(self, batch: TensorDictBase):
"""
pass

def on_train_end(self, training_td: TensorDictBase):
def on_train_step(self, batch: TensorDictBase, group: str) -> TensorDictBase:
"""
A callback called at the end of every training step.
A callback called for every training step.
Args:
batch (TensorDictBase): tensordict with the training batch
group (str): group name
Returns:
TensorDictBase: a new tensordict containing the loss values
"""
pass

def on_train_end(self, training_td: TensorDictBase, group: str):
"""
A callback called at the end of training.
Args:
training_td (TensorDictBase): tensordict containing the loss values
group (str): group name
"""
pass
@@ -65,9 +80,20 @@ def on_batch_collected(self, batch: TensorDictBase):
for callback in self.callbacks:
callback.on_batch_collected(batch)

def on_train_end(self, training_td: TensorDictBase):
def on_train_step(self, batch: TensorDictBase, group: str) -> TensorDictBase:
train_td = None
for callback in self.callbacks:
td = callback.on_train_step(batch, group)
if td is not None:
if train_td is None:
train_td = td
else:
train_td.update(td)
return train_td

def on_train_end(self, training_td: TensorDictBase, group: str):
for callback in self.callbacks:
callback.on_train_end(training_td)
callback.on_train_end(training_td, group)

def on_evaluation_end(self, rollouts: List[TensorDictBase]):
for callback in self.callbacks:
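A minimal sketch of a user callback built on the new hook, mirroring the signatures above. The auxiliary metric and the (group, "observation") key layout are illustrative assumptions; whatever on_train_step returns is merged into the group's training tensordict.

    from benchmarl.experiment.callback import Callback
    from tensordict import TensorDict, TensorDictBase

    class AuxMetricCallback(Callback):
        def on_train_step(self, batch: TensorDictBase, group: str) -> TensorDictBase:
            # Compute an extra scalar from the training batch; assumes the
            # group's observations live under the (group, "observation") key.
            obs = batch.get((group, "observation"))
            return TensorDict({f"aux_obs_mean_{group}": obs.float().mean()}, [])

        def on_train_end(self, training_td: TensorDictBase, group: str):
            # Receives the merged tensordict, including the value added above.
            print(f"{group}: {training_td}")
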
18 changes: 12 additions & 6 deletions benchmarl/experiment/experiment.py
@@ -7,6 +7,7 @@
from __future__ import annotations

import importlib

import os
import time
from collections import OrderedDict
@@ -535,7 +536,6 @@ def _collection_loop(self):
step=self.n_iters_performed,
)
pbar.set_description(f"mean return = {self.mean_return}", refresh=False)
pbar.update()

# Callback
self.on_batch_collected(batch)
@@ -561,7 +561,7 @@ def _collection_loop(self):
)

# Callback
self.on_train_end(training_td)
self.on_train_end(training_td, group)

# Exploration update
if isinstance(self.group_policies[group], TensorDictSequential):
@@ -607,6 +607,7 @@ def _collection_loop(self):
and self.total_frames % self.config.checkpoint_interval == 0
):
self._save_experiment()
pbar.update()
sampling_start = time.time()

self.close()
Expand Down Expand Up @@ -638,6 +639,7 @@ def _optimizer_loop(self, group: str) -> TensorDictBase:
loss_value.backward()

grad_norm = self._grad_clip(optimizer)

training_td.set(
f"grad_norm_{loss_name}",
torch.tensor(grad_norm, device=self.config.train_device),
@@ -648,6 +650,11 @@ def _optimizer_loop(self, group: str) -> TensorDictBase:
self.replay_buffers[group].update_tensordict_priority(subdata)
if self.target_updaters[group] is not None:
self.target_updaters[group].step()

callback_loss = self.on_train_step(subdata, group)
if callback_loss is not None:
training_td.update(callback_loss)

return training_td

def _grad_clip(self, optimizer: torch.optim.Optimizer) -> float:
Expand All @@ -672,10 +679,9 @@ def _evaluation_loop(self):
video_frames = []

def callback(env, td):
try:
video_frames.append(env.render(mode="rgb_array"))
except TypeError:
video_frames.append(env.render())
video_frames.append(
self.task.__class__.render_callback(self, env, td)
)

else:
video_frames = None
5 changes: 3 additions & 2 deletions benchmarl/experiment/logger.py
@@ -88,7 +88,6 @@ def log_collection(
total_frames: int,
step: int,
) -> float:

to_log = {}
json_metrics = {}
for group in self.group_map.keys():
@@ -218,7 +217,9 @@ def log_evaluation(
self.log(to_log, step=step)
if video_frames is not None:
vid = torch.tensor(
np.transpose(video_frames[: rollouts[0].batch_size[0]], (0, 3, 1, 2)),
np.transpose(
video_frames[: rollouts[0].batch_size[0] - 1], (0, 3, 1, 2)
),
dtype=torch.uint8,
).unsqueeze(0)
for logger in self.loggers:
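The slice caps the video at the rollout length minus one, matching the commit note that the render callback fires max_steps - 1 times. A small sketch of the resulting tensor layout, with purely illustrative shapes:

    import numpy as np
    import torch

    # Suppose an evaluation rollout with max_steps = 100 rendered 99 RGB frames.
    max_steps = 100
    video_frames = [np.zeros((64, 64, 3), dtype=np.uint8) for _ in range(max_steps - 1)]

    # Keep at most max_steps - 1 frames and move channels first.
    vid = torch.tensor(
        np.transpose(np.stack(video_frames[: max_steps - 1]), (0, 3, 1, 2)),
        dtype=torch.uint8,
    ).unsqueeze(0)
    print(vid.shape)  # torch.Size([1, 99, 3, 64, 64])
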
9 changes: 6 additions & 3 deletions examples/callback/custom_callback.py
@@ -11,14 +11,18 @@
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.experiment.callback import Callback
from benchmarl.models.mlp import MlpConfig
from tensordict import TensorDictBase
from tensordict import TensorDict, TensorDictBase


class MyCallbackA(Callback):
def on_batch_collected(self, batch: TensorDictBase):
print(f"Callback A is doing something with the sampling batch {batch}")

def on_train_end(self, training_td: TensorDictBase):
def on_train_step(self, batch: TensorDictBase, group: str) -> TensorDictBase:
print(f"Callback A is computing a loss with the training tensordict {batch}")
return TensorDict({}, [])

def on_train_end(self, training_td: TensorDictBase, group: str):
print(
f"Callback A is doing something with the training tensordict {training_td}"
)
@@ -37,7 +41,6 @@ def on_evaluation_end(self, rollouts: List[TensorDictBase]):


if __name__ == "__main__":

experiment_config = ExperimentConfig.get_from_yaml()
task = VmasTask.BALANCE.get_from_yaml()
algorithm_config = MappoConfig.get_from_yaml()
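The tail of the example is truncated here. Continuing from the names defined above, the callbacks would typically be handed to the experiment as sketched below; the exact Experiment arguments are assumptions based on the rest of BenchMARL's API, not shown in this diff.

    experiment = Experiment(
        task=task,
        algorithm_config=algorithm_config,
        model_config=MlpConfig.get_from_yaml(),
        seed=0,
        config=experiment_config,
        callbacks=[MyCallbackA()],  # plus any other callbacks defined above
    )
    experiment.run()
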
