Skip to content

Commit

Permalink
added save overlay/freq config
Browse files Browse the repository at this point in the history
  • Loading branch information
kywch committed Mar 12, 2024
1 parent 8917275 commit 090bbd0
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 9 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,4 @@ session_*
experiments
runs
video
wandb
4 changes: 4 additions & 0 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ debug:
total_timesteps: 1000
save_checkpoint: True
checkpoint_interval: 4
save_overlay: True
overlay_interval: 4
verbose: True
env_pool: False

Expand Down Expand Up @@ -73,6 +75,8 @@ train:
data_dir: runs
save_checkpoint: False
checkpoint_interval: 200
save_overlay: True
overlay_interval: 200
cpu_offload: True
pool_kernel: [0]

Expand Down
16 changes: 7 additions & 9 deletions pokemonred_puffer/cleanrl_puffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,15 +481,13 @@ def evaluate(self):
self.stats = {}
self.max_stats = {}
for k, v in self.infos["learner"].items():
if "Task_eval_fn" in k:
# Temporary hack for NMMO competition
continue
if "pokemon_exploration_map" in k:
# self.exploration_map_agg[env_id, :, :] = v
# overlay = make_pokemon_red_overlay(self.exploration_map_agg)
overlay = make_pokemon_red_overlay(np.stack(v, axis=0))
if self.wandb is not None:
self.stats["Media/aggregate_exploration_map"] = self.wandb.Image(overlay)
if "pokemon_exploration_map" in k and config.save_overlay is True:
if self.update % config.overlay_interval == 0:
# self.exploration_map_agg[env_id, :, :] = v
# overlay = make_pokemon_red_overlay(self.exploration_map_agg)
overlay = make_pokemon_red_overlay(np.stack(v, axis=0))
if self.wandb is not None:
self.stats["Media/aggregate_exploration_map"] = self.wandb.Image(overlay)
try: # TODO: Better checks on log data types
self.stats[f"Histogram/{k}"] = self.wandb.Histogram(v, num_bins=16)
self.stats[k] = np.mean(v)
Expand Down

0 comments on commit 090bbd0

Please sign in to comment.