From 97bc1e046a65d1da87e89b241b7f930052b05729 Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Sat, 12 Oct 2024 22:41:54 -0400 Subject: [PATCH] events are all you need? --- config.yaml | 2 +- pokemonred_puffer/sweep.py | 30 +++++++++++++++++++----------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/config.yaml b/config.yaml index 630568b..1a47b43 100644 --- a/config.yaml +++ b/config.yaml @@ -93,7 +93,7 @@ train: compile: True compile_mode: "reduce-overhead" float32_matmul_precision: "high" - total_timesteps: 600_000_000 # 100_000_000_000 for full games + total_timesteps: 500_000_000 # 100_000_000_000 for full games batch_size: 65536 minibatch_size: 2048 learning_rate: 2.0e-4 diff --git a/pokemonred_puffer/sweep.py b/pokemonred_puffer/sweep.py index b3ee3c6..4602841 100644 --- a/pokemonred_puffer/sweep.py +++ b/pokemonred_puffer/sweep.py @@ -25,7 +25,9 @@ app = typer.Typer(pretty_exceptions_enable=False) -def sweep_config_to_params(sweep_config: DictConfig, prefix: str = "") -> list[Param]: +def sweep_config_to_params( + base_config: DictConfig | int | float | bool | None, sweep_config: DictConfig, prefix: str = "" +) -> list[Param]: res = [] for k, v in sweep_config.items(): # A little hacky. Maybe I should not make this all config based @@ -35,13 +37,17 @@ def sweep_config_to_params(sweep_config: DictConfig, prefix: str = "") -> list[P Param( prefix.removesuffix("-").removeprefix("-"), param_class(**v), - (v["max"] + v["min"]) // 2 - if v.get("is_integer", False) - else math.sqrt(v["max"] * v["min"]), + base_config + if base_config is not None + else ( + (v["max"] + v["min"]) // 2 + if param_class == "LinearSpace" + else math.sqrt(v["max"] * v["min"]) + ), ) ] elif isinstance(v, DictConfig): - res += sweep_config_to_params(v, prefix=prefix + "-" + k) + res += sweep_config_to_params(base_config.get(k, None), v, prefix=prefix + "-" + k) else: print(type(v)) return res @@ -83,7 +89,9 @@ def launch_sweep( is_wandb_logging_enabled=False, wandb_params=WandbLoggingParams(project_name="Pokemon", run_name="Pokemon"), ) - params = sweep_config_to_params(sweep_config) + params = sweep_config_to_params(base_config, sweep_config) + breakpoint() + carbs = CARBS(config=config, params=params) sweep_id = wandb.sweep( sweep={ @@ -109,12 +117,12 @@ def launch_sweep( finished = set([]) while not sweep.done(): # Taken from sweep.schedule. Limits runs to only one at a time. - # if not (sweep._controller and sweep._controller.get("schedule")): # Only one run will be scheduled at a time - suggestion = carbs.suggest() - run = sweeps.SweepRun(config={"x": {"value": suggestion.suggestion}}) sweep._step() - sweep.schedule(run) + if not (sweep._controller and sweep._controller.get("schedule")): + suggestion = carbs.suggest() + run = sweeps.SweepRun(config={"x": {"value": suggestion.suggestion}}) + sweep.schedule(run) # without this nothing updates... sweep_obj = sweep._sweep_obj if runs := sweep_obj["runs"]: @@ -134,7 +142,7 @@ def launch_sweep( summary_metrics = json.loads(run["summaryMetrics"]) obs_in = ObservationInParam( input=json.loads(run["config"])["x"]["value"], - output=summary_metrics["environment/reward_sum"], + output=summary_metrics["environment/stats/event"], cost=summary_metrics["performance/uptime"], ) carbs.observe(obs_in)