Skip to content

Commit

Permalink
more wandb fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Jun 21, 2024
1 parent 9cd294b commit d38e231
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions pokemonred_puffer/cleanrl_puffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def evaluate(self):
if (
hasattr(self.config, "swarm_frequency")
and hasattr(self.config, "swarm_keep_pct")
and self.update % self.config.swarm_frequency == 0
and self.epoch % self.config.swarm_frequency == 0
and "reward/event" in self.infos
):
# collect the top swarm_keep_pct % of envs
Expand Down Expand Up @@ -291,18 +291,16 @@ def evaluate(self):
# Moves into models... maybe. Definitely moves.
# You could also just return infos and have it in demo
if "pokemon_exploration_map" in self.infos and self.config.save_overlay is True:
if self.update % self.config.overlay_interval == 0:
if self.epoch % self.config.overlay_interval == 0:
overlay = make_pokemon_red_overlay(
np.stack(self.infos["pokemon_exploration_map"], axis=0)
)
if self.wandb_client is not None:
self.stats["Media/aggregate_exploration_map"] = self.wandb_client.Image(
overlay
)
self.stats["Media/aggregate_exploration_map"] = wandb.Image(overlay)

for k, v in self.infos.items():
if "_map" in k and self.wandb_client is not None:
self.stats[f"Media/{k}"] = self.wandb_client.Image(v[0])
self.stats[f"Media/{k}"] = wandb.Image(v[0])
continue
elif "state" in k:
pass
Expand Down Expand Up @@ -467,9 +465,9 @@ def train(self):
self.last_log_time = time.time()
self.wandb_client.log(
{
"0verview/SPS": self.profile.SPS,
"0verview/agent_steps": self.global_step,
"0verview/learning_rate": self.optimizer.param_groups[0]["lr"],
"Overview/SPS": self.profile.SPS,
"Overview/agent_steps": self.global_step,
"Overview/learning_rate": self.optimizer.param_groups[0]["lr"],
**{f"environment/{k}": v for k, v in self.stats.items()},
**{f"losses/{k}": v for k, v in self.losses.__dict__.items()},
**{f"performance/{k}": v for k, v in self.profile},
Expand All @@ -486,7 +484,7 @@ def close(self):

if self.wandb_client is not None:
artifact_name = f"{self.exp_name}_model"
artifact = self.wandb_client.Artifact(artifact_name, type="model")
artifact = wandb.Artifact(artifact_name, type="model")
model_path = self.save_checkpoint()
artifact.add_file(model_path)
self.wandb_client.run.log_artifact(artifact)
Expand Down

0 comments on commit d38e231

Please sign in to comment.