diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index c2e78cc..88509ce 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -1431,6 +1431,9 @@ def agent_stats(self, action): "max_steps": self.get_max_steps(), # redundant but this is so we don't interfere with the swarm logic "required_count": len(self.required_events) + len(self.required_items), + "event_plus_required": self.progress_reward["event"] + + len(self.required_events) + + len(self.required_items), } | { "exploration": { diff --git a/pokemonred_puffer/sweep.py b/pokemonred_puffer/sweep.py index 30dee40..d9a8d94 100644 --- a/pokemonred_puffer/sweep.py +++ b/pokemonred_puffer/sweep.py @@ -108,7 +108,7 @@ def launch_sweep( "controller": {"type": "local"}, "parameters": {p.name: {"min": p.space.min, "max": p.space.max} for p in params}, "metric": { - "name": "environment/stats/required_count", + "name": "environment/stats/event_plus_required", "goal": "maximize", "goal_value": 100, }, @@ -170,7 +170,7 @@ def launch_sweep( if k != "wandb_version" }, # TODO: try out other stats like required count - output=summary_metrics["environment/stats/required_count"], + output=summary_metrics["environment/stats/event_plus_required"], cost=summary_metrics["performance/uptime"], ) carbs.observe(obs_in)