diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py index ec21099..31438e7 100644 --- a/pokemonred_puffer/cleanrl_puffer.py +++ b/pokemonred_puffer/cleanrl_puffer.py @@ -206,7 +206,7 @@ def evaluate(self): if ( hasattr(self.config, "swarm_frequency") and hasattr(self.config, "swarm_keep_pct") - and self.update % self.config.swarm_frequency == 0 + and self.epoch % self.config.swarm_frequency == 0 and "reward/event" in self.infos ): # collect the top swarm_keep_pct % of envs @@ -291,18 +291,16 @@ def evaluate(self): # Moves into models... maybe. Definitely moves. # You could also just return infos and have it in demo if "pokemon_exploration_map" in self.infos and self.config.save_overlay is True: - if self.update % self.config.overlay_interval == 0: + if self.epoch % self.config.overlay_interval == 0: overlay = make_pokemon_red_overlay( np.stack(self.infos["pokemon_exploration_map"], axis=0) ) if self.wandb_client is not None: - self.stats["Media/aggregate_exploration_map"] = self.wandb_client.Image( - overlay - ) + self.stats["Media/aggregate_exploration_map"] = wandb.Image(overlay) for k, v in self.infos.items(): if "_map" in k and self.wandb_client is not None: - self.stats[f"Media/{k}"] = self.wandb_client.Image(v[0]) + self.stats[f"Media/{k}"] = wandb.Image(v[0]) continue elif "state" in k: pass @@ -467,9 +465,9 @@ def train(self): self.last_log_time = time.time() self.wandb_client.log( { - "0verview/SPS": self.profile.SPS, - "0verview/agent_steps": self.global_step, - "0verview/learning_rate": self.optimizer.param_groups[0]["lr"], + "Overview/SPS": self.profile.SPS, + "Overview/agent_steps": self.global_step, + "Overview/learning_rate": self.optimizer.param_groups[0]["lr"], **{f"environment/{k}": v for k, v in self.stats.items()}, **{f"losses/{k}": v for k, v in self.losses.__dict__.items()}, **{f"performance/{k}": v for k, v in self.profile}, @@ -486,7 +484,7 @@ def close(self): if self.wandb_client is not None: artifact_name = f"{self.exp_name}_model" - artifact = self.wandb_client.Artifact(artifact_name, type="model") + artifact = wandb.Artifact(artifact_name, type="model") model_path = self.save_checkpoint() artifact.add_file(model_path) self.wandb_client.run.log_artifact(artifact)