Skip to content

Commit

Permalink
events are all you need?
Browse files (browse the repository at this point in the history)
  • Loading branch information
thatguy11325 committed Oct 13, 2024
1 parent eb3c977 commit 97bc1e0
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 12 deletions.
2 changes: 1 addition & 1 deletion config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ train:
compile: True
compile_mode: "reduce-overhead"
float32_matmul_precision: "high"
total_timesteps: 600_000_000 # 100_000_000_000 for full games
total_timesteps: 500_000_000 # 100_000_000_000 for full games
batch_size: 65536
minibatch_size: 2048
learning_rate: 2.0e-4
Expand Down
30 changes: 19 additions & 11 deletions pokemonred_puffer/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
app = typer.Typer(pretty_exceptions_enable=False)


def sweep_config_to_params(sweep_config: DictConfig, prefix: str = "") -> list[Param]:
def sweep_config_to_params(
base_config: DictConfig | int | float | bool | None, sweep_config: DictConfig, prefix: str = ""
) -> list[Param]:
res = []
for k, v in sweep_config.items():
# A little hacky. Maybe I should not make this all config based
Expand All @@ -35,13 +37,17 @@ def sweep_config_to_params(sweep_config: DictConfig, prefix: str = "") -> list[P
Param(
prefix.removesuffix("-").removeprefix("-"),
param_class(**v),
(v["max"] + v["min"]) // 2
if v.get("is_integer", False)
else math.sqrt(v["max"] * v["min"]),
base_config
if base_config is not None
else (
(v["max"] + v["min"]) // 2
if param_class == "LinearSpace"
else math.sqrt(v["max"] * v["min"])
),
)
]
elif isinstance(v, DictConfig):
res += sweep_config_to_params(v, prefix=prefix + "-" + k)
res += sweep_config_to_params(base_config.get(k, None), v, prefix=prefix + "-" + k)
else:
print(type(v))
return res
Expand Down Expand Up @@ -83,7 +89,9 @@ def launch_sweep(
is_wandb_logging_enabled=False,
wandb_params=WandbLoggingParams(project_name="Pokemon", run_name="Pokemon"),
)
params = sweep_config_to_params(sweep_config)
params = sweep_config_to_params(base_config, sweep_config)
breakpoint()

carbs = CARBS(config=config, params=params)
sweep_id = wandb.sweep(
sweep={
Expand All @@ -109,12 +117,12 @@ def launch_sweep(
finished = set([])
while not sweep.done():
# Taken from sweep.schedule. Limits runs to only one at a time.
# if not (sweep._controller and sweep._controller.get("schedule")):
# Only one run will be scheduled at a time
suggestion = carbs.suggest()
run = sweeps.SweepRun(config={"x": {"value": suggestion.suggestion}})
sweep._step()
sweep.schedule(run)
if not (sweep._controller and sweep._controller.get("schedule")):
suggestion = carbs.suggest()
run = sweeps.SweepRun(config={"x": {"value": suggestion.suggestion}})
sweep.schedule(run)
# without this nothing updates...
sweep_obj = sweep._sweep_obj
if runs := sweep_obj["runs"]:
Expand All @@ -134,7 +142,7 @@ def launch_sweep(
summary_metrics = json.loads(run["summaryMetrics"])
obs_in = ObservationInParam(
input=json.loads(run["config"])["x"]["value"],
output=summary_metrics["environment/reward_sum"],
output=summary_metrics["environment/stats/event"],
cost=summary_metrics["performance/uptime"],
)
carbs.observe(obs_in)
Expand Down

0 comments on commit 97bc1e0

Please sign in to comment.