Skip to content

Commit

Permalink
Resumable carbs
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Oct 13, 2024
1 parent e66c181 commit 4182d4a
Showing 1 changed file with 31 additions and 18 deletions.
49 changes: 31 additions & 18 deletions pokemonred_puffer/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,26 +82,36 @@ def launch_sweep(
),
] = "sweep-config.yaml",
sweep_name: Annotated[str, typer.Option(help="Sweep name")] = "PokeSweep",
sweep_id: Annotated[
str | None,
typer.Option(
help="Sweep id to use. If specified, a previous sweep will be resumed. "
"N.B. The sweep and base config MUST BE THE SAME"
),
] = None,
):
console = Console()
config = CARBSParams(
better_direction_sign=1,
is_wandb_logging_enabled=False,
wandb_params=WandbLoggingParams(project_name="Pokemon", run_name="Pokemon"),
)
params = sweep_config_to_params(base_config, sweep_config)

carbs = CARBS(config=config, params=params)
sweep_id = wandb.sweep(
sweep={
"name": sweep_name,
"controller": {"type": "local"},
"parameters": {},
"command": ["${args_json}"],
},
entity=base_config.wandb.entity,
project=base_config.wandb.project,
)
if not sweep_id:
config = CARBSParams(
better_direction_sign=1,
is_wandb_logging_enabled=False,
wandb_params=WandbLoggingParams(project_name="Pokemon", run_name=sweep_id),
)
params = sweep_config_to_params(base_config, sweep_config)

carbs = CARBS(config=config, params=params)
sweep_id = wandb.sweep(
sweep={
"name": sweep_name,
"controller": {"type": "local"},
"parameters": {},
"command": ["${args_json}"],
},
entity=base_config.wandb.entity,
project=base_config.wandb.project,
)
else:
carbs = CARBS.warm_start_from_wandb(run_name=sweep_id, is_prior_observation_valid=True)

import pprint

Expand Down Expand Up @@ -150,6 +160,9 @@ def launch_sweep(
cost=summary_metrics["performance/uptime"],
)
carbs.observe(obs_in)
# Because wandb stages the artifacts we have to keep these files
# dangling around wasting good disk space.
carbs.save_to_file(hash(finished) + ".pt", upload_to_wandb=True)
elif run["state"] == RunState.pending:
print(f"PENDING RUN FOUND {run['name']}")
sweep.print_status()
Expand Down

0 comments on commit 4182d4a

Please sign in to comment.