From b7374a0d4a01aacf5649704a405a0328c6b13984 Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Sat, 26 Oct 2024 20:44:10 -0400 Subject: [PATCH] attempt to remove the x.value hack --- pokemonred_puffer/sweep.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/pokemonred_puffer/sweep.py b/pokemonred_puffer/sweep.py index c8fa29f..20b0536 100644 --- a/pokemonred_puffer/sweep.py +++ b/pokemonred_puffer/sweep.py @@ -136,7 +136,9 @@ def launch_sweep( sweep._step() if not (sweep._controller and sweep._controller.get("schedule")): suggestion = carbs.suggest() - run = sweeps.SweepRun(config=suggestion.suggestion) + run = sweeps.SweepRun( + config={k: {"value": v} for k, v in suggestion.suggestion.items()} + ) sweep.schedule(run) # without this nothing updates... sweep_obj = sweep._sweep_obj @@ -161,7 +163,7 @@ def launch_sweep( and "performance/uptime" in summary_metrics ): obs_in = ObservationInParam( - input=json.loads(run["config"]), + input={k: v["value"] for k, v in json.loads(run["config"]).items()}, # TODO: try out other stats like required count output=summary_metrics["environment/stats/required_count"], cost=summary_metrics["performance/uptime"], @@ -185,7 +187,9 @@ def launch_agent( debug: bool = False, ): def _fn(): - agent_config: DictConfig = OmegaConf.load(os.environ["WANDB_SWEEP_PARAM_PATH"]) + agent_config: DictConfig = OmegaConf.create( + {k: v.value for k, v in OmegaConf.load(os.environ["WANDB_SWEEP_PARAM_PATH"]).items()} + ) agent_config = update_base_config(base_config, agent_config) train.train(config=agent_config, debug=debug, track=True)