Skip to content

Commit

Permalink
attempt to remove the x.value hack
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Oct 27, 2024
1 parent 8206f9c commit b7374a0
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions pokemonred_puffer/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ def launch_sweep(
sweep._step()
if not (sweep._controller and sweep._controller.get("schedule")):
suggestion = carbs.suggest()
run = sweeps.SweepRun(config=suggestion.suggestion)
run = sweeps.SweepRun(
config={k: {"value": v} for k, v in suggestion.suggestion.items()}
)
sweep.schedule(run)
# without this nothing updates...
sweep_obj = sweep._sweep_obj
Expand All @@ -161,7 +163,7 @@ def launch_sweep(
and "performance/uptime" in summary_metrics
):
obs_in = ObservationInParam(
input=json.loads(run["config"]),
input={k: v["value"] for k, v in json.loads(run["config"]).items()},
# TODO: try out other stats like required count
output=summary_metrics["environment/stats/required_count"],
cost=summary_metrics["performance/uptime"],
Expand All @@ -185,7 +187,9 @@ def launch_agent(
debug: bool = False,
):
def _fn():
agent_config: DictConfig = OmegaConf.load(os.environ["WANDB_SWEEP_PARAM_PATH"])
agent_config: DictConfig = OmegaConf.create(
{k: v.value for k, v in OmegaConf.load(os.environ["WANDB_SWEEP_PARAM_PATH"]).items()}
)
agent_config = update_base_config(base_config, agent_config)
train.train(config=agent_config, debug=debug, track=True)

Expand Down

0 comments on commit b7374a0

Please sign in to comment.