diff --git a/pokemonred_puffer/sweep.py b/pokemonred_puffer/sweep.py index 31ac840..e4f5b4a 100644 --- a/pokemonred_puffer/sweep.py +++ b/pokemonred_puffer/sweep.py @@ -91,10 +91,25 @@ def launch_sweep( "N.B. The sweep and base config MUST BE THE SAME" ), ] = None, + dry_run: Annotated[bool, typer.Option(help="Attempts to start CARBS, but not wandb")] = False, ): console = Console() if not sweep_id: params = sweep_config_to_params(base_config, sweep_config) + for param in params: + print(f"Checking param: {param}") + assert ( + param.space.min + < param.search_center - param.space.scale + < param.search_center + param.space.scale + < param.space.max + ), ( + f"{param.space.min} < " + f"{param.search_center} - {param.space.scale} " + f"< {param.search_center} + {param.space.scale} " + f"< {param.space.max}" + ) + config = CARBSParams( better_direction_sign=1, is_wandb_logging_enabled=False, @@ -104,84 +119,91 @@ def launch_sweep( ) carbs = CARBS(config=config, params=params) - sweep_id = wandb.sweep( - sweep={ - "name": sweep_name, - "controller": {"type": "local"}, - "parameters": {p.name: {"min": p.space.min, "max": p.space.max} for p in params}, - "metric": { - "name": "environment/stats/required_count", - "goal": "maximize", - "goal_value": 100, + if not dry_run: + sweep_id = wandb.sweep( + sweep={ + "name": sweep_name, + "controller": {"type": "local"}, + "parameters": { + p.name: {"min": p.space.min, "max": p.space.max} for p in params + }, + "metric": { + "name": "environment/stats/required_count", + "goal": "maximize", + "goal_value": 100, + }, + "command": ["${args_json}"], }, - "command": ["${args_json}"], - }, - entity=base_config.wandb.entity, - project=base_config.wandb.project, - ) + entity=base_config.wandb.entity, + project=base_config.wandb.project, + ) else: carbs = CARBS.warm_start_from_wandb(run_name=sweep_id, is_prior_observation_valid=True) import pprint pprint.pprint(params) - sweep = wandb.controller(sweep_id) - - console.print(f"Beginning sweep with id {sweep_id}", style="bold") - console.print( - f"On all nodes please run python -m pokemonred_puffer.sweep launch-agent {sweep_id}", - style="bold", - ) - finished = [] - while not sweep.done(): - # Taken from sweep.schedule. Limits runs to only one at a time. - # Only one run will be scheduled at a time - sweep._step() - if not (sweep._controller and sweep._controller.get("schedule")): - suggestion = carbs.suggest() - run = sweeps.SweepRun( - config={k: {"value": v} for k, v in suggestion.suggestion.items()} - ) - sweep.schedule(run) - # without this nothing updates... - sweep_obj = sweep._sweep_obj - if runs := sweep_obj["runs"]: - for run in runs: - if run["state"] == RunState.running.value: - pass - elif ( - run["state"] - in [ - RunState.failed.value, - RunState.finished.value, - RunState.crashed.value, - ] - and run["name"] not in finished - ): - finished.append(run["name"]) - if summaryMetrics_json := run.get("summaryMetrics", None): - summary_metrics = json.loads(summaryMetrics_json) - if ( - "environment/stats/required_count" in summary_metrics - and "performance/uptime" in summary_metrics - ): - obs_in = ObservationInParam( - input={ - k: v["value"] - for k, v in json.loads(run["config"]).items() - if k != "wandb_version" - }, - # TODO: try out other stats like required count - output=summary_metrics["environment/stats/required_count"], - cost=summary_metrics["performance/uptime"], - ) - carbs.observe(obs_in) - # Because wandb stages the artifacts we have to keep these files - # dangling around wasting good disk space. - # carbs.save_to_file(hash(tuple(finished)) + ".pt", upload_to_wandb=True) - elif run["state"] == RunState.pending: - print(f"PENDING RUN FOUND {run['name']}") - sweep.print_status() + if not dry_run: + sweep = wandb.controller(sweep_id) + + console.print(f"Beginning sweep with id {sweep_id}", style="bold") + console.print( + f"On all nodes please run python -m pokemonred_puffer.sweep launch-agent {sweep_id}", + style="bold", + ) + + finished = [] + while not sweep.done(): + # Taken from sweep.schedule. Limits runs to only one at a time. + # Only one run will be scheduled at a time + sweep._step() + if not (sweep._controller and sweep._controller.get("schedule")): + suggestion = carbs.suggest() + run = sweeps.SweepRun( + config={k: {"value": v} for k, v in suggestion.suggestion.items()} + ) + sweep.schedule(run) + # without this nothing updates... + sweep_obj = sweep._sweep_obj + if runs := sweep_obj["runs"]: + for run in runs: + if run["state"] == RunState.running.value: + pass + elif ( + run["state"] + in [ + RunState.failed.value, + RunState.finished.value, + RunState.crashed.value, + ] + and run["name"] not in finished + ): + finished.append(run["name"]) + if summaryMetrics_json := run.get("summaryMetrics", None): + summary_metrics = json.loads(summaryMetrics_json) + if ( + "environment/stats/required_count" in summary_metrics + and "performance/uptime" in summary_metrics + ): + obs_in = ObservationInParam( + input={ + k: v["value"] + for k, v in json.loads(run["config"]).items() + if k != "wandb_version" + }, + # TODO: try out other stats like required count + output=summary_metrics["environment/stats/required_count"], + cost=summary_metrics["performance/uptime"], + ) + carbs.observe(obs_in) + # Because wandb stages the artifacts we have to keep these files + # dangling around wasting good disk space. + # carbs.save_to_file(hash(tuple(finished)) + ".pt", upload_to_wandb=True) + elif run["state"] == RunState.pending: + print(f"PENDING RUN FOUND {run['name']}") + sweep.print_status() + else: + suggestion = carbs.suggest() @app.command() diff --git a/sweep-config.yaml b/sweep-config.yaml index 9b307d7..372a552 100644 --- a/sweep-config.yaml +++ b/sweep-config.yaml @@ -17,27 +17,27 @@ train: carbs.utils.LogSpace: min: 1.0e-5 max: 1.0e-3 - scale: .05 + scale: 1.0e-4 gamma: carbs.utils.LogitSpace: min: .5 max: 1.0 - scale: .05 + scale: .0005 gae_lambda: carbs.utils.LogitSpace: min: .5 max: 1.0 - scale: .05 + scale: .01 ent_coef: carbs.utils.LogSpace: min: 1.0e-5 max: 1.0e-1 - scale: .05 + scale: .005 vf_ent_coef: carbs.utils.LogSpace: min: 1.0e-5 max: 1.0e-1 - scale: .05 + scale: 1e-4 rewards: baseline.ObjectRewardRequiredEventsMapIds: @@ -65,7 +65,7 @@ rewards: hm_count: carbs.utils.LogSpace: min: 1.0e-3 - max: 10.0 + max: 15.0 scale: 2.0 level: carbs.utils.LogSpace: @@ -81,7 +81,7 @@ rewards: carbs.utils.LogSpace: min: 1.0e-7 max: 1.0 - scale: 1e-3 + scale: 5e-5 explore_signs: carbs.utils.LogSpace: min: 1.0e-4