From 7272fecc1174e5f28d9d1886e48c2919ca523134 Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Sat, 9 Nov 2024 21:10:52 -0500 Subject: [PATCH] fix wandb saving --- pokemonred_puffer/sweep.py | 226 +++++++++++++++++++------------------ 1 file changed, 115 insertions(+), 111 deletions(-) diff --git a/pokemonred_puffer/sweep.py b/pokemonred_puffer/sweep.py index c2eea25..9255899 100644 --- a/pokemonred_puffer/sweep.py +++ b/pokemonred_puffer/sweep.py @@ -2,26 +2,27 @@ import math import multiprocessing as mp import os +import pprint from typing import Annotated import carbs.utils import sweeps -from sweeps import RunState import typer from carbs import ( CARBS, + CARBSParams, LogitSpace, + ObservationInParam, Param, ParamDictType, ParamType, - CARBSParams, WandbLoggingParams, - ObservationInParam, ) from omegaconf import DictConfig, OmegaConf from rich.console import Console -import wandb +from sweeps import RunState +import wandb from pokemonred_puffer import train from pokemonred_puffer.environment import RedGymEnv @@ -95,49 +96,55 @@ def launch_sweep( dry_run: Annotated[bool, typer.Option(help="Attempts to start CARBS, but not wandb")] = False, ): console = Console() - if not sweep_id: - params = sweep_config_to_params(base_config, sweep_config) - for param in params: - print(f"Checking param: {param}") - if isinstance(param.space, LogitSpace): - assert ( - 0.0 - <= param.space.min - < param.search_center - param.space.scale - < param.search_center + param.space.scale - < param.space.max - <= 1.0 - ), ( - "0.0 " - f"<= {param.space.min} " - f"< {param.search_center} - {param.space.scale} " - f"< {param.search_center} + {param.space.scale} " - f"< {param.space.max} " - f"<= 1.0" - ) - else: - assert ( - param.space.min - < param.search_center - param.space.scale - < param.search_center + param.space.scale - < param.space.max - ), ( - f"{param.space.min} " - f"< {param.search_center} - {param.space.scale} " - f"< {param.search_center} + {param.space.scale} " - f"< {param.space.max}" - ) + params = sweep_config_to_params(base_config, sweep_config) + for param in params: + print(f"Checking param: {param}") + if isinstance(param.space, LogitSpace): + assert ( + 0.0 + <= param.space.min + < param.search_center - param.space.scale + < param.search_center + param.space.scale + < param.space.max + <= 1.0 + ), ( + "0.0 " + f"<= {param.space.min} " + f"< {param.search_center} - {param.space.scale} " + f"< {param.search_center} + {param.space.scale} " + f"< {param.space.max} " + f"<= 1.0" + ) + else: + assert ( + param.space.min + < param.search_center - param.space.scale + < param.search_center + param.space.scale + < param.space.max + ), ( + f"{param.space.min} " + f"< {param.search_center} - {param.space.scale} " + f"< {param.search_center} + {param.space.scale} " + f"< {param.space.max}" + ) - config = CARBSParams( - better_direction_sign=1, - is_wandb_logging_enabled=False, - wandb_params=WandbLoggingParams(project_name="Pokemon", run_name=sweep_id), - resample_frequency=5, - num_random_samples=len(params), - ) + config = CARBSParams( + better_direction_sign=1, + is_wandb_logging_enabled=True, + wandb_params=WandbLoggingParams( + project_name="Pokemon", + run_id=sweep_id, + root_dir="carbs", + ), + resample_frequency=5, + num_random_samples=len(params), + ) - carbs = CARBS(config=config, params=params) - if not dry_run: + carbs = CARBS(config=config, params=params) + if not dry_run: + if sweep_id: + carbs.warm_start_from_wandb(run_name=sweep_id, is_prior_observation_valid=True) + else: sweep_id = wandb.sweep( sweep={ "name": sweep_name, @@ -155,76 +162,73 @@ def launch_sweep( entity=base_config.wandb.entity, project=base_config.wandb.project, ) - else: - carbs = CARBS.warm_start_from_wandb(run_name=sweep_id, is_prior_observation_valid=True) - - import pprint pprint.pprint(params) - if not dry_run: - sweep = wandb.controller(sweep_id) + if dry_run: + carbs.suggest() + return - console.print(f"Beginning sweep with id {sweep_id}", style="bold") - console.print( - f"On all nodes please run python -m pokemonred_puffer.sweep launch-agent {sweep_id}", - style="bold", - ) + sweep = wandb.controller(sweep_id) - finished = [] - while not sweep.done(): - # Taken from sweep.schedule. Limits runs to only one at a time. - # Only one run will be scheduled at a time - sweep._step() - if not (sweep._controller and sweep._controller.get("schedule")): - suggestion = carbs.suggest() - run = sweeps.SweepRun( - config={k: {"value": v} for k, v in suggestion.suggestion.items()} - ) - sweep.schedule(run) - # without this nothing updates... - sweep_obj = sweep._sweep_obj - if runs := sweep_obj["runs"]: - for run in runs: - if run["state"] == RunState.running.value: - pass - elif ( - run["state"] - in [ - RunState.failed.value, - RunState.finished.value, - RunState.crashed.value, - ] - and run["name"] not in finished - ): - finished.append(run["name"]) - if summaryMetrics_json := run.get("summaryMetrics", None): - summary_metrics = json.loads(summaryMetrics_json) - if ( - "environment/stats/required_count" in summary_metrics - and "performance/uptime" in summary_metrics - # Only count agents that have run more than 1M steps - and "Overview/agent_steps" in summary_metrics - and summary_metrics["Overview/agent_steps"] > 1e6 - ): - obs_in = ObservationInParam( - input={ - k: v["value"] - for k, v in json.loads(run["config"]).items() - if k != "wandb_version" - }, - # TODO: try out other stats like required count - output=summary_metrics["environment/stats/required_count"], - cost=summary_metrics["performance/uptime"], - ) - carbs.observe(obs_in) - # Because wandb stages the artifacts we have to keep these files - # dangling around wasting good disk space. - # carbs.save_to_file(hash(tuple(finished)) + ".pt", upload_to_wandb=True) - elif run["state"] == RunState.pending: - print(f"PENDING RUN FOUND {run['name']}") - sweep.print_status() - else: - suggestion = carbs.suggest() + console.print(f"Beginning sweep with id {sweep_id}", style="bold") + console.print( + f"On all nodes please run python -m pokemonred_puffer.sweep launch-agent {sweep_id}", + style="bold", + ) + + finished = [] + while not sweep.done(): + # Taken from sweep.schedule. Limits runs to only one at a time. + # Only one run will be scheduled at a time + sweep._step() + if not (sweep._controller and sweep._controller.get("schedule")): + suggestion = carbs.suggest() + run = sweeps.SweepRun( + config={k: {"value": v} for k, v in suggestion.suggestion.items()} + ) + sweep.schedule(run) + # without this nothing updates... + sweep_obj = sweep._sweep_obj + if runs := sweep_obj["runs"]: + for run in runs: + if run["state"] == RunState.running.value: + pass + elif ( + run["state"] + in [ + RunState.failed.value, + RunState.finished.value, + RunState.crashed.value, + ] + and run["name"] not in finished + ): + finished.append(run["name"]) + if summaryMetrics_json := run.get("summaryMetrics", None): + summary_metrics = json.loads(summaryMetrics_json) + if ( + "environment/stats/required_count" in summary_metrics + and "performance/uptime" in summary_metrics + # Only count agents that have run more than 1M steps + and "Overview/agent_steps" in summary_metrics + and summary_metrics["Overview/agent_steps"] > 1e6 + ): + obs_in = ObservationInParam( + input={ + k: v["value"] + for k, v in json.loads(run["config"]).items() + if k != "wandb_version" + }, + # TODO: try out other stats like required count + output=summary_metrics["environment/stats/required_count"], + cost=summary_metrics["performance/uptime"], + ) + carbs.observe(obs_in) + # Because wandb stages the artifacts we have to keep these files + # dangling around wasting good disk space. + # carbs.save_to_file(hash(tuple(finished)) + ".pt", upload_to_wandb=True) + elif run["state"] == RunState.pending: + print(f"PENDING RUN FOUND {run['name']}") + sweep.print_status() @app.command()