diff --git a/pokemonred_puffer/sweep.py b/pokemonred_puffer/sweep.py index 0f95e7b..56e02eb 100644 --- a/pokemonred_puffer/sweep.py +++ b/pokemonred_puffer/sweep.py @@ -98,6 +98,7 @@ def launch_sweep( ): console = Console() params = sweep_config_to_params(base_config, sweep_config) + params_keys = {p.name for p in params} for param in params: print(f"Checking param: {param}") if isinstance(param.space, LogitSpace): @@ -256,7 +257,7 @@ def launch_sweep( input={ k: v["value"] for k, v in json.loads(run["config"]).items() - if k != "wandb_version" + if k in params_keys }, # TODO: try out other stats like required count output=summary_metrics["environment/stats/required_count"],