
[Feature] More callbacks #35

Merged · 5 commits · Nov 8, 2023

Changes from all commits
7 changes: 7 additions & 0 deletions benchmarl/environments/common.py
@@ -238,6 +238,13 @@ def get_reward_sum_transform(self, env: EnvBase) -> Transform:
"""
return RewardSum(reset_keys=env.reset_keys)

@staticmethod
def render_callback(experiment, env: EnvBase, data: TensorDictBase):
try:
return env.render(mode="rgb_array")
except TypeError:
return env.render()

def __repr__(self):
cls_name = self.__class__.__name__
return f"{cls_name}.{self.name}: (config={self.config})"
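The new `render_callback` static method gives each task a single override point for grabbing frames during evaluation: it first tries the older `render(mode="rgb_array")` signature and falls back to a plain `render()` call. A task class can specialize it further. A minimal sketch of such an override, assuming a hypothetical `MyTask` and frame post-processing that are not part of this PR:

import numpy as np
from tensordict import TensorDictBase
from torchrl.envs import EnvBase

from benchmarl.environments.common import Task


class MyTask(Task):
    # Hypothetical override: downscale every rendered frame before logging.
    @staticmethod
    def render_callback(experiment, env: EnvBase, data: TensorDictBase):
        frame = env.render()  # assumes this env returns an rgb array by default
        return np.ascontiguousarray(frame[::2, ::2])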
34 changes: 30 additions & 4 deletions benchmarl/experiment/callback.py
@@ -34,12 +34,27 @@ def on_batch_collected(self, batch: TensorDictBase):
"""
pass

def on_train_end(self, training_td: TensorDictBase):
def on_train_step(self, batch: TensorDictBase, group: str) -> TensorDictBase:
"""
A callback called at the end of every training step.
A callback called for every training step.
Args:
batch (TensorDictBase): tensordict with the training batch
group (str): group name
Returns:
TensorDictBase: a new tensordict containing the loss values
"""
pass

def on_train_end(self, training_td: TensorDictBase, group: str):
"""
A callback called at the end of training.
Args:
training_td (TensorDictBase): tensordict containing the loss values
group (str): group name
"""
pass
@@ -65,9 +80,20 @@ def on_batch_collected(self, batch: TensorDictBase):
         for callback in self.callbacks:
             callback.on_batch_collected(batch)
 
-    def on_train_end(self, training_td: TensorDictBase):
+    def on_train_step(self, batch: TensorDictBase, group: str) -> TensorDictBase:
+        train_td = None
+        for callback in self.callbacks:
+            td = callback.on_train_step(batch, group)
+            if td is not None:
+                if train_td is None:
+                    train_td = td
+                else:
+                    train_td.update(td)
+        return train_td
+
+    def on_train_end(self, training_td: TensorDictBase, group: str):
         for callback in self.callbacks:
-            callback.on_train_end(training_td)
+            callback.on_train_end(training_td, group)
 
     def on_evaluation_end(self, rollouts: List[TensorDictBase]):
         for callback in self.callbacks:
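Worth noting about the new `CallbackNotifier.on_train_step`: the first non-None tensordict becomes the accumulator and every later result is folded in with `update`, so when two callbacks return the same key, the one registered later wins. A self-contained sketch of that merge order (illustration only, not BenchMARL code):

import torch
from tensordict import TensorDict

# Results two hypothetical callbacks might return from on_train_step.
td_a = TensorDict({"loss_aux": torch.tensor(1.0)}, [])
td_b = TensorDict({"loss_aux": torch.tensor(2.0), "loss_reg": torch.tensor(0.1)}, [])

merged = td_a
merged.update(td_b)  # duplicate key: the later callback's value wins
print(merged["loss_aux"])  # tensor(2.)
print(merged["loss_reg"])  # tensor(0.1000)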
18 changes: 12 additions & 6 deletions benchmarl/experiment/experiment.py
@@ -7,6 +7,7 @@
 from __future__ import annotations
 
 import importlib
+
 import os
 import time
 from collections import OrderedDict
@@ -535,7 +536,6 @@ def _collection_loop(self):
                 step=self.n_iters_performed,
             )
             pbar.set_description(f"mean return = {self.mean_return}", refresh=False)
-            pbar.update()
 
             # Callback
             self.on_batch_collected(batch)
@@ -561,7 +561,7 @@ def _collection_loop(self):
                 )
 
                 # Callback
-                self.on_train_end(training_td)
+                self.on_train_end(training_td, group)
 
                 # Exploration update
                 if isinstance(self.group_policies[group], TensorDictSequential):
@@ -607,6 +607,7 @@ def _collection_loop(self):
                 and self.total_frames % self.config.checkpoint_interval == 0
             ):
                 self._save_experiment()
+            pbar.update()
             sampling_start = time.time()
 
         self.close()
@@ -638,6 +639,7 @@ def _optimizer_loop(self, group: str) -> TensorDictBase:
             loss_value.backward()
 
             grad_norm = self._grad_clip(optimizer)
+
             training_td.set(
                 f"grad_norm_{loss_name}",
                 torch.tensor(grad_norm, device=self.config.train_device),
@@ -648,6 +650,11 @@ def _optimizer_loop(self, group: str) -> TensorDictBase:
             self.replay_buffers[group].update_tensordict_priority(subdata)
         if self.target_updaters[group] is not None:
             self.target_updaters[group].step()
+
+        callback_loss = self.on_train_step(subdata, group)
+        if callback_loss is not None:
+            training_td.update(callback_loss)
+
         return training_td
 
     def _grad_clip(self, optimizer: torch.optim.Optimizer) -> float:
@@ -672,10 +679,9 @@ def _evaluation_loop(self):
             video_frames = []
 
             def callback(env, td):
-                try:
-                    video_frames.append(env.render(mode="rgb_array"))
-                except TypeError:
-                    video_frames.append(env.render())
+                video_frames.append(
+                    self.task.__class__.render_callback(self, env, td)
+                )
 
         else:
             video_frames = None
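Since `_optimizer_loop` now forwards the sampled `subdata` to `on_train_step` and merges whatever comes back into `training_td`, a callback's returned keys end up in the same tensordict as the built-in losses and grad norms. A hedged sketch of a callback using that hook (the `loss_entropy_bonus` key and the constant value are illustrative, not defined by this PR):

import torch
from tensordict import TensorDict, TensorDictBase

from benchmarl.experiment.callback import Callback


class ExtraLossLogger(Callback):
    def on_train_step(self, batch: TensorDictBase, group: str) -> TensorDictBase:
        # Compute any auxiliary metric from the training batch here;
        # a constant stands in for a real computation.
        return TensorDict({"loss_entropy_bonus": torch.tensor(0.0)}, [])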
5 changes: 3 additions & 2 deletions benchmarl/experiment/logger.py
@@ -88,7 +88,6 @@ def log_collection(
         total_frames: int,
         step: int,
     ) -> float:
-
         to_log = {}
         json_metrics = {}
         for group in self.group_map.keys():
@@ -218,7 +217,9 @@ def log_evaluation(
         self.log(to_log, step=step)
         if video_frames is not None:
             vid = torch.tensor(
-                np.transpose(video_frames[: rollouts[0].batch_size[0]], (0, 3, 1, 2)),
+                np.transpose(
+                    video_frames[: rollouts[0].batch_size[0] - 1], (0, 3, 1, 2)
+                ),
                 dtype=torch.uint8,
             ).unsqueeze(0)
         for logger in self.loggers:
9 changes: 6 additions & 3 deletions examples/callback/custom_callback.py
@@ -11,14 +11,18 @@
 from benchmarl.experiment import Experiment, ExperimentConfig
 from benchmarl.experiment.callback import Callback
 from benchmarl.models.mlp import MlpConfig
-from tensordict import TensorDictBase
+from tensordict import TensorDict, TensorDictBase
 
 
 class MyCallbackA(Callback):
     def on_batch_collected(self, batch: TensorDictBase):
         print(f"Callback A is doing something with the sampling batch {batch}")
 
-    def on_train_end(self, training_td: TensorDictBase):
+    def on_train_step(self, batch: TensorDictBase, group: str) -> TensorDictBase:
+        print(f"Callback A is computing a loss with the training tensordict {batch}")
+        return TensorDict({}, [])
+
+    def on_train_end(self, training_td: TensorDictBase, group: str):
         print(
             f"Callback A is doing something with the training tensordict {training_td}"
         )
@@ -37,7 +41,6 @@ def on_evaluation_end(self, rollouts: List[TensorDictBase]):


 if __name__ == "__main__":
-
     experiment_config = ExperimentConfig.get_from_yaml()
     task = VmasTask.BALANCE.get_from_yaml()
     algorithm_config = MappoConfig.get_from_yaml()
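The example file is truncated here; presumably it goes on to register the callbacks on the experiment. A sketch of how that wiring typically looks with the `callbacks` argument (argument values are illustrative, and the keyword names mirror the Experiment constructor as I understand it):

# Illustrative wiring; the real example file continues past the shown diff.
experiment = Experiment(
    task=task,
    algorithm_config=algorithm_config,
    model_config=MlpConfig.get_from_yaml(),
    critic_model_config=MlpConfig.get_from_yaml(),
    seed=0,
    config=experiment_config,
    callbacks=[MyCallbackA()],  # plus any other callbacks the file defines
)
experiment.run()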