Skip to content

Commit

Permalink
add some debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Oct 29, 2024
1 parent ffc8772 commit 6c8e13f
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 78 deletions.
164 changes: 93 additions & 71 deletions pokemonred_puffer/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,25 @@ def launch_sweep(
"N.B. The sweep and base config MUST BE THE SAME"
),
] = None,
dry_run: Annotated[bool, typer.Option(help="Attempts to start CARBS, but not wandb")] = False,
):
console = Console()
if not sweep_id:
params = sweep_config_to_params(base_config, sweep_config)
for param in params:
print(f"Checking param: {param}")
assert (
param.space.min
< param.search_center - param.space.scale
< param.search_center + param.space.scale
< param.space.max
), (
f"{param.space.min} < "
f"{param.search_center} - {param.space.scale} "
f"< {param.search_center} + {param.space.scale} "
f"< {param.space.max}"
)

config = CARBSParams(
better_direction_sign=1,
is_wandb_logging_enabled=False,
Expand All @@ -104,84 +119,91 @@ def launch_sweep(
)

carbs = CARBS(config=config, params=params)
sweep_id = wandb.sweep(
sweep={
"name": sweep_name,
"controller": {"type": "local"},
"parameters": {p.name: {"min": p.space.min, "max": p.space.max} for p in params},
"metric": {
"name": "environment/stats/required_count",
"goal": "maximize",
"goal_value": 100,
if not dry_run:
sweep_id = wandb.sweep(
sweep={
"name": sweep_name,
"controller": {"type": "local"},
"parameters": {
p.name: {"min": p.space.min, "max": p.space.max} for p in params
},
"metric": {
"name": "environment/stats/required_count",
"goal": "maximize",
"goal_value": 100,
},
"command": ["${args_json}"],
},
"command": ["${args_json}"],
},
entity=base_config.wandb.entity,
project=base_config.wandb.project,
)
entity=base_config.wandb.entity,
project=base_config.wandb.project,
)
else:
carbs = CARBS.warm_start_from_wandb(run_name=sweep_id, is_prior_observation_valid=True)

import pprint

pprint.pprint(params)
sweep = wandb.controller(sweep_id)

console.print(f"Beginning sweep with id {sweep_id}", style="bold")
console.print(
f"On all nodes please run python -m pokemonred_puffer.sweep launch-agent {sweep_id}",
style="bold",
)
finished = []
while not sweep.done():
# Taken from sweep.schedule. Limits runs to only one at a time.
# Only one run will be scheduled at a time
sweep._step()
if not (sweep._controller and sweep._controller.get("schedule")):
suggestion = carbs.suggest()
run = sweeps.SweepRun(
config={k: {"value": v} for k, v in suggestion.suggestion.items()}
)
sweep.schedule(run)
# without this nothing updates...
sweep_obj = sweep._sweep_obj
if runs := sweep_obj["runs"]:
for run in runs:
if run["state"] == RunState.running.value:
pass
elif (
run["state"]
in [
RunState.failed.value,
RunState.finished.value,
RunState.crashed.value,
]
and run["name"] not in finished
):
finished.append(run["name"])
if summaryMetrics_json := run.get("summaryMetrics", None):
summary_metrics = json.loads(summaryMetrics_json)
if (
"environment/stats/required_count" in summary_metrics
and "performance/uptime" in summary_metrics
):
obs_in = ObservationInParam(
input={
k: v["value"]
for k, v in json.loads(run["config"]).items()
if k != "wandb_version"
},
# TODO: try out other stats like required count
output=summary_metrics["environment/stats/required_count"],
cost=summary_metrics["performance/uptime"],
)
carbs.observe(obs_in)
# Because wandb stages the artifacts we have to keep these files
# dangling around wasting good disk space.
# carbs.save_to_file(hash(tuple(finished)) + ".pt", upload_to_wandb=True)
elif run["state"] == RunState.pending:
print(f"PENDING RUN FOUND {run['name']}")
sweep.print_status()
if not dry_run:
sweep = wandb.controller(sweep_id)

console.print(f"Beginning sweep with id {sweep_id}", style="bold")
console.print(
f"On all nodes please run python -m pokemonred_puffer.sweep launch-agent {sweep_id}",
style="bold",
)

finished = []
while not sweep.done():
# Taken from sweep.schedule. Limits runs to only one at a time.
# Only one run will be scheduled at a time
sweep._step()
if not (sweep._controller and sweep._controller.get("schedule")):
suggestion = carbs.suggest()
run = sweeps.SweepRun(
config={k: {"value": v} for k, v in suggestion.suggestion.items()}
)
sweep.schedule(run)
# without this nothing updates...
sweep_obj = sweep._sweep_obj
if runs := sweep_obj["runs"]:
for run in runs:
if run["state"] == RunState.running.value:
pass
elif (
run["state"]
in [
RunState.failed.value,
RunState.finished.value,
RunState.crashed.value,
]
and run["name"] not in finished
):
finished.append(run["name"])
if summaryMetrics_json := run.get("summaryMetrics", None):
summary_metrics = json.loads(summaryMetrics_json)
if (
"environment/stats/required_count" in summary_metrics
and "performance/uptime" in summary_metrics
):
obs_in = ObservationInParam(
input={
k: v["value"]
for k, v in json.loads(run["config"]).items()
if k != "wandb_version"
},
# TODO: try out other stats like required count
output=summary_metrics["environment/stats/required_count"],
cost=summary_metrics["performance/uptime"],
)
carbs.observe(obs_in)
# Because wandb stages the artifacts we have to keep these files
# dangling around wasting good disk space.
# carbs.save_to_file(hash(tuple(finished)) + ".pt", upload_to_wandb=True)
elif run["state"] == RunState.pending:
print(f"PENDING RUN FOUND {run['name']}")
sweep.print_status()
else:
suggestion = carbs.suggest()


@app.command()
Expand Down
14 changes: 7 additions & 7 deletions sweep-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,27 +17,27 @@ train:
carbs.utils.LogSpace:
min: 1.0e-5
max: 1.0e-3
scale: .05
scale: 1.0e-4
gamma:
carbs.utils.LogitSpace:
min: .5
max: 1.0
scale: .05
scale: .0005
gae_lambda:
carbs.utils.LogitSpace:
min: .5
max: 1.0
scale: .05
scale: .01
ent_coef:
carbs.utils.LogSpace:
min: 1.0e-5
max: 1.0e-1
scale: .05
scale: .005
vf_ent_coef:
carbs.utils.LogSpace:
min: 1.0e-5
max: 1.0e-1
scale: .05
scale: 1e-4

rewards:
baseline.ObjectRewardRequiredEventsMapIds:
Expand Down Expand Up @@ -65,7 +65,7 @@ rewards:
hm_count:
carbs.utils.LogSpace:
min: 1.0e-3
max: 10.0
max: 15.0
scale: 2.0
level:
carbs.utils.LogSpace:
Expand All @@ -81,7 +81,7 @@ rewards:
carbs.utils.LogSpace:
min: 1.0e-7
max: 1.0
scale: 1e-3
scale: 5e-5
explore_signs:
carbs.utils.LogSpace:
min: 1.0e-4
Expand Down

0 comments on commit 6c8e13f

Please sign in to comment.