From 15b3a785fae323989c246d971442e652b8ea823c Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Sun, 10 Nov 2024 14:29:19 -0500 Subject: [PATCH] Just rely on local saving --- pokemonred_puffer/sweep.py | 104 ++++++++++++++++++++++++------------- 1 file changed, 69 insertions(+), 35 deletions(-) diff --git a/pokemonred_puffer/sweep.py b/pokemonred_puffer/sweep.py index c112d54..6d2117e 100644 --- a/pokemonred_puffer/sweep.py +++ b/pokemonred_puffer/sweep.py @@ -3,6 +3,7 @@ import multiprocessing as mp import os import pprint +import re from typing import Annotated import carbs.utils @@ -128,51 +129,84 @@ def launch_sweep( f"< {param.space.max}" ) - config = CARBSParams( - better_direction_sign=1, - is_wandb_logging_enabled=True, - wandb_params=WandbLoggingParams( - project_name="Pokemon", - run_id=sweep_id, - root_dir="carbs", - ), - resample_frequency=5, - num_random_samples=len(params), - ) + if sweep_id: + config = CARBSParams( + better_direction_sign=1, + is_wandb_logging_enabled=True, + wandb_params=WandbLoggingParams( + project_name="Pokemon", + run_id=sweep_id, + run_name=sweep_name, + ), + resample_frequency=5, + num_random_samples=len(params), + checkpoint_dir=f"carbs/checkpoints/{sweep_id}", + ) - carbs = CARBS(config=config, params=params) - if not dry_run: - if sweep_id: - # runname = entity/project/run_id - carbs.warm_start_from_wandb( - run_name=f"{base_config.wandb.entity}/{base_config.wandb.project}/{sweep_id}", + carbs = CARBS(config=config, params=params) + # for wandb + # runname = entity/project/run_id + # find most recent file in checkpoint dir + experiment_dir = f"carbs/checkpoints/{sweep_id}/{sweep_name}" + saves = [ + save_filename + for save_filename in os.listdir(experiment_dir) + if re.match(r"carbs_\d+obs.pt", save_filename) + ] + + if saves: + # sort by the middle int and take the highest value + # dont need split, could also use a regex group + save_filename = sorted( + saves, key=lambda x: int(x.replace("carbs_", "").replace("obs.pt", "")) + )[0] + carbs.warm_start( + filename=os.path.join(experiment_dir, save_filename), is_prior_observation_valid=True, ) - else: - sweep_id = wandb.sweep( - sweep={ - "name": sweep_name, - "controller": {"type": "local"}, - "parameters": { - p.name: {"min": p.space.min, "max": p.space.max} for p in params - }, - "metric": { - "name": "environment/stats/required_count", - "goal": "maximize", - "goal_value": 100, - }, - "command": ["${args_json}"], + + if not sweep_id and not dry_run: + sweep_id = wandb.sweep( + sweep={ + "name": sweep_name, + "controller": {"type": "local"}, + "parameters": {p.name: {"min": p.space.min, "max": p.space.max} for p in params}, + "metric": { + "name": "environment/stats/required_count", + "goal": "maximize", }, - entity=base_config.wandb.entity, - project=base_config.wandb.project, - ) + "command": ["${args_json}"], + }, + entity=base_config.wandb.entity, + project=base_config.wandb.project, + ) + config = CARBSParams( + better_direction_sign=1, + is_wandb_logging_enabled=True, + wandb_params=WandbLoggingParams( + project_name="Pokemon", + run_id=sweep_id, + run_name=sweep_name, + ), + resample_frequency=5, + num_random_samples=len(params), + checkpoint_dir=f"carbs/checkpoints/{sweep_id}", + ) + carbs = CARBS(config=config, params=params) + os.makedirs(os.path.join(config.checkpoint_dir, carbs.experiment_name), exist_ok=True) + + carbs._autosave() pprint.pprint(params) if dry_run: carbs.suggest() return - sweep = wandb.controller(sweep_id) + sweep = wandb.controller( + sweep_id_or_config=sweep_id, + entity=base_config.wandb.entity, + project=base_config.wandb.project, + ) console.print(f"Beginning sweep with id {sweep_id}", style="bold") console.print(