Commit

[BugFix] Small fix to multi-group eval and add wandb project_name (#126)

* fix eval

* choose project name

* fix render
matteobettini authored Sep 6, 2024
1 parent aef8d40 commit 308228e
Showing 3 changed files with 19 additions and 10 deletions.
2 changes: 2 additions & 0 deletions benchmarl/conf/experiment/base_experiment.yaml
@@ -87,6 +87,8 @@ evaluation_deterministic_actions: True
 
 # List of loggers to use, options are: wandb, csv, tensorboard, mflow
 loggers: []
+# Wandb project name
+project_name: "benchmarl"
 # Create a json folder as part of the output in the format of marl-eval
 create_json: True
 
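The new `project_name` entry behaves like any other experiment option: it defaults to "benchmarl" and can be overridden in a custom experiment YAML, as a Hydra override (e.g. `experiment.project_name=my_project`), or programmatically. A minimal programmatic sketch, assuming a standard BenchMARL install (import paths as in the BenchMARL README):

    from benchmarl.experiment import ExperimentConfig

    # Loads the defaults from benchmarl/conf/experiment/base_experiment.yaml
    experiment_config = ExperimentConfig.get_from_yaml()
    print(experiment_config.project_name)  # "benchmarl"

    # Override the wandb project for this run
    experiment_config.loggers = ["wandb"]
    experiment_config.project_name = "my_project"
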
2 changes: 2 additions & 0 deletions benchmarl/experiment/experiment.py
@@ -94,6 +94,7 @@ class ExperimentConfig:
     evaluation_deterministic_actions: bool = MISSING
 
     loggers: List[str] = MISSING
+    project_name: str = MISSING
     create_json: bool = MISSING
 
     save_folder: Optional[str] = MISSING
@@ -535,6 +536,7 @@ def _setup_name(self):
 
     def _setup_logger(self):
         self.logger = Logger(
+            project_name=self.config.project_name,
             experiment_name=self.name,
             folder_name=str(self.folder_name),
             experiment_config=self.config,
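
With the field plumbed through `_setup_logger`, an end-to-end run that logs to a custom wandb project might look like the sketch below (entry points follow the BenchMARL README; treat the exact task/algorithm/model names as illustrative assumptions):

    from benchmarl.algorithms import MappoConfig
    from benchmarl.environments import VmasTask
    from benchmarl.experiment import Experiment, ExperimentConfig
    from benchmarl.models.mlp import MlpConfig

    experiment_config = ExperimentConfig.get_from_yaml()
    experiment_config.loggers = ["wandb"]
    experiment_config.project_name = "my_project"  # new option added in this commit

    experiment = Experiment(
        task=VmasTask.BALANCE.get_from_yaml(),
        algorithm_config=MappoConfig.get_from_yaml(),
        model_config=MlpConfig.get_from_yaml(),
        seed=0,
        config=experiment_config,
    )
    experiment.run()  # wandb runs now land under "my_project" instead of "benchmarl"
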
25 changes: 15 additions & 10 deletions benchmarl/experiment/logger.py
@@ -33,6 +33,7 @@ def __init__(
         model_name: str,
         group_map: Dict[str, List[str]],
         seed: int,
+        project_name: str,
     ):
         self.experiment_config = experiment_config
         self.algorithm_name = algorithm_name
@@ -63,7 +64,7 @@ def __init__(
                     experiment_name=experiment_name,
                     wandb_kwargs={
                         "group": task_name,
-                        "project": "benchmarl",
+                        "project": project_name,
                         "id": experiment_name,
                     },
                 )
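
For reference, the kwargs above are forwarded to TorchRL's WandbLogger; the effect is roughly equivalent to initialising a run as below (a loose approximation, not the actual call path, with illustrative values):

    import wandb

    # Illustrative values; in BenchMARL these come from the experiment/task config.
    project_name = "my_project"
    task_name = "vmas/balance"
    experiment_name = "mappo_balance_mlp__seed_0"

    wandb.init(
        project=project_name,  # now user-configurable instead of hard-coded "benchmarl"
        group=task_name,
        id=experiment_name,
        name=experiment_name,
    )
    wandb.finish()
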
@@ -165,9 +166,11 @@ def log_evaluation(
             return
         to_log = {}
         json_metrics = {}
+        max_length_rollout_0 = 0
         for group in self.group_map.keys():
             # Cut the rollouts at the first done
-            for k, r in enumerate(rollouts):
+            rollouts_group = []
+            for i, r in enumerate(rollouts):
                 next_done = self._get_done(group, r)
                 # Reduce it to batch size
                 next_done = next_done.sum(
@@ -178,19 +181,23 @@
                 done_index = next_done.nonzero(as_tuple=True)[0]
                 if done_index.numel() > 0:
                     done_index = done_index[0]
-                    rollouts[k] = r[: done_index + 1]
+                    r = r[: done_index + 1]
+                if i == 0:
+                    max_length_rollout_0 = max(r.batch_size[0], max_length_rollout_0)
+                rollouts_group.append(r)
 
             returns = [
-                self._get_reward(group, td).sum(0).mean().item() for td in rollouts
+                self._get_reward(group, td).sum(0).mean().item()
+                for td in rollouts_group
             ]
             json_metrics[group + "_return"] = torch.tensor(
-                returns, device=rollouts[0].device
+                returns, device=rollouts_group[0].device
             )
             to_log.update(
                 {
                     f"eval/{group}/reward/episode_reward_min": min(returns),
                     f"eval/{group}/reward/episode_reward_mean": sum(returns)
-                    / len(rollouts),
+                    / len(rollouts_group),
                     f"eval/{group}/reward/episode_reward_max": max(returns),
                 }
             )
@@ -223,10 +230,8 @@ def log_evaluation(
             )
 
         self.log(to_log, step=step)
-        if video_frames is not None and rollouts[0].batch_size[0] > 1:
-            video_frames = np.stack(
-                video_frames[: rollouts[0].batch_size[0] - 1], axis=0
-            )
+        if video_frames is not None and max_length_rollout_0 > 1:
+            video_frames = np.stack(video_frames[: max_length_rollout_0 - 1], axis=0)
             vid = torch.tensor(
                 np.transpose(video_frames, (0, 3, 1, 2)),
                 dtype=torch.uint8,
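
The multi-group fix itself is in the hunks above: previously each group's loop truncated the shared `rollouts` list in place (`rollouts[k] = r[: done_index + 1]`), so any group processed after the first computed its returns on rollouts already cut at an earlier group's done step, and the video-frame slicing depended on that mutated first rollout. The fix keeps per-group copies in `rollouts_group` and tracks `max_length_rollout_0` separately for the video frames. A toy sketch of the aliasing problem with plain Python lists (not BenchMARL code; done steps chosen for illustration):

    # One evaluation rollout of length 10; the two groups reach "done" at different steps.
    rollouts = [list(range(10))]
    done_step = {"agents_a": 3, "agents_b": 7}

    # Old behaviour: truncate the shared list in place.
    lengths_old = {}
    for group, done in done_step.items():
        for k, r in enumerate(rollouts):
            rollouts[k] = r[: done + 1]  # mutates the list every other group will see
        lengths_old[group] = len(rollouts[0])

    # Fixed behaviour: build a per-group copy and leave `rollouts` untouched.
    rollouts = [list(range(10))]
    lengths_new = {}
    for group, done in done_step.items():
        rollouts_group = [r[: done + 1] for r in rollouts]
        lengths_new[group] = len(rollouts_group[0])

    print(lengths_old)  # {'agents_a': 4, 'agents_b': 4}  <- second group wrongly truncated
    print(lengths_new)  # {'agents_a': 4, 'agents_b': 8}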
