Skip to content

Commit

Permalink
fix wandb saving
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Nov 10, 2024
1 parent 61cc44d commit 7272fec
Showing 1 changed file with 115 additions and 111 deletions.
226 changes: 115 additions & 111 deletions pokemonred_puffer/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,27 @@
import math
import multiprocessing as mp
import os
import pprint
from typing import Annotated

import carbs.utils
import sweeps
from sweeps import RunState
import typer
from carbs import (
CARBS,
CARBSParams,
LogitSpace,
ObservationInParam,
Param,
ParamDictType,
ParamType,
CARBSParams,
WandbLoggingParams,
ObservationInParam,
)
from omegaconf import DictConfig, OmegaConf
from rich.console import Console
import wandb
from sweeps import RunState

import wandb
from pokemonred_puffer import train
from pokemonred_puffer.environment import RedGymEnv

Expand Down Expand Up @@ -95,49 +96,55 @@ def launch_sweep(
dry_run: Annotated[bool, typer.Option(help="Attempts to start CARBS, but not wandb")] = False,
):
console = Console()
if not sweep_id:
params = sweep_config_to_params(base_config, sweep_config)
for param in params:
print(f"Checking param: {param}")
if isinstance(param.space, LogitSpace):
assert (
0.0
<= param.space.min
< param.search_center - param.space.scale
< param.search_center + param.space.scale
< param.space.max
<= 1.0
), (
"0.0 "
f"<= {param.space.min} "
f"< {param.search_center} - {param.space.scale} "
f"< {param.search_center} + {param.space.scale} "
f"< {param.space.max} "
f"<= 1.0"
)
else:
assert (
param.space.min
< param.search_center - param.space.scale
< param.search_center + param.space.scale
< param.space.max
), (
f"{param.space.min} "
f"< {param.search_center} - {param.space.scale} "
f"< {param.search_center} + {param.space.scale} "
f"< {param.space.max}"
)
params = sweep_config_to_params(base_config, sweep_config)
for param in params:
print(f"Checking param: {param}")
if isinstance(param.space, LogitSpace):
assert (
0.0
<= param.space.min
< param.search_center - param.space.scale
< param.search_center + param.space.scale
< param.space.max
<= 1.0
), (
"0.0 "
f"<= {param.space.min} "
f"< {param.search_center} - {param.space.scale} "
f"< {param.search_center} + {param.space.scale} "
f"< {param.space.max} "
f"<= 1.0"
)
else:
assert (
param.space.min
< param.search_center - param.space.scale
< param.search_center + param.space.scale
< param.space.max
), (
f"{param.space.min} "
f"< {param.search_center} - {param.space.scale} "
f"< {param.search_center} + {param.space.scale} "
f"< {param.space.max}"
)

config = CARBSParams(
better_direction_sign=1,
is_wandb_logging_enabled=False,
wandb_params=WandbLoggingParams(project_name="Pokemon", run_name=sweep_id),
resample_frequency=5,
num_random_samples=len(params),
)
config = CARBSParams(
better_direction_sign=1,
is_wandb_logging_enabled=True,
wandb_params=WandbLoggingParams(
project_name="Pokemon",
run_id=sweep_id,
root_dir="carbs",
),
resample_frequency=5,
num_random_samples=len(params),
)

carbs = CARBS(config=config, params=params)
if not dry_run:
carbs = CARBS(config=config, params=params)
if not dry_run:
if sweep_id:
carbs.warm_start_from_wandb(run_name=sweep_id, is_prior_observation_valid=True)
else:
sweep_id = wandb.sweep(
sweep={
"name": sweep_name,
Expand All @@ -155,76 +162,73 @@ def launch_sweep(
entity=base_config.wandb.entity,
project=base_config.wandb.project,
)
else:
carbs = CARBS.warm_start_from_wandb(run_name=sweep_id, is_prior_observation_valid=True)

import pprint

pprint.pprint(params)
if not dry_run:
sweep = wandb.controller(sweep_id)
if dry_run:
carbs.suggest()
return

console.print(f"Beginning sweep with id {sweep_id}", style="bold")
console.print(
f"On all nodes please run python -m pokemonred_puffer.sweep launch-agent {sweep_id}",
style="bold",
)
sweep = wandb.controller(sweep_id)

finished = []
while not sweep.done():
# Taken from sweep.schedule. Limits runs to only one at a time.
# Only one run will be scheduled at a time
sweep._step()
if not (sweep._controller and sweep._controller.get("schedule")):
suggestion = carbs.suggest()
run = sweeps.SweepRun(
config={k: {"value": v} for k, v in suggestion.suggestion.items()}
)
sweep.schedule(run)
# without this nothing updates...
sweep_obj = sweep._sweep_obj
if runs := sweep_obj["runs"]:
for run in runs:
if run["state"] == RunState.running.value:
pass
elif (
run["state"]
in [
RunState.failed.value,
RunState.finished.value,
RunState.crashed.value,
]
and run["name"] not in finished
):
finished.append(run["name"])
if summaryMetrics_json := run.get("summaryMetrics", None):
summary_metrics = json.loads(summaryMetrics_json)
if (
"environment/stats/required_count" in summary_metrics
and "performance/uptime" in summary_metrics
# Only count agents that have run more than 1M steps
and "Overview/agent_steps" in summary_metrics
and summary_metrics["Overview/agent_steps"] > 1e6
):
obs_in = ObservationInParam(
input={
k: v["value"]
for k, v in json.loads(run["config"]).items()
if k != "wandb_version"
},
# TODO: try out other stats like required count
output=summary_metrics["environment/stats/required_count"],
cost=summary_metrics["performance/uptime"],
)
carbs.observe(obs_in)
# Because wandb stages the artifacts we have to keep these files
# dangling around wasting good disk space.
# carbs.save_to_file(hash(tuple(finished)) + ".pt", upload_to_wandb=True)
elif run["state"] == RunState.pending:
print(f"PENDING RUN FOUND {run['name']}")
sweep.print_status()
else:
suggestion = carbs.suggest()
console.print(f"Beginning sweep with id {sweep_id}", style="bold")
console.print(
f"On all nodes please run python -m pokemonred_puffer.sweep launch-agent {sweep_id}",
style="bold",
)

finished = []
while not sweep.done():
# Taken from sweep.schedule. Limits runs to only one at a time.
# Only one run will be scheduled at a time
sweep._step()
if not (sweep._controller and sweep._controller.get("schedule")):
suggestion = carbs.suggest()
run = sweeps.SweepRun(
config={k: {"value": v} for k, v in suggestion.suggestion.items()}
)
sweep.schedule(run)
# without this nothing updates...
sweep_obj = sweep._sweep_obj
if runs := sweep_obj["runs"]:
for run in runs:
if run["state"] == RunState.running.value:
pass
elif (
run["state"]
in [
RunState.failed.value,
RunState.finished.value,
RunState.crashed.value,
]
and run["name"] not in finished
):
finished.append(run["name"])
if summaryMetrics_json := run.get("summaryMetrics", None):
summary_metrics = json.loads(summaryMetrics_json)
if (
"environment/stats/required_count" in summary_metrics
and "performance/uptime" in summary_metrics
# Only count agents that have run more than 1M steps
and "Overview/agent_steps" in summary_metrics
and summary_metrics["Overview/agent_steps"] > 1e6
):
obs_in = ObservationInParam(
input={
k: v["value"]
for k, v in json.loads(run["config"]).items()
if k != "wandb_version"
},
# TODO: try out other stats like required count
output=summary_metrics["environment/stats/required_count"],
cost=summary_metrics["performance/uptime"],
)
carbs.observe(obs_in)
# Because wandb stages the artifacts we have to keep these files
# dangling around wasting good disk space.
# carbs.save_to_file(hash(tuple(finished)) + ".pt", upload_to_wandb=True)
elif run["state"] == RunState.pending:
print(f"PENDING RUN FOUND {run['name']}")
sweep.print_status()


@app.command()
Expand Down

0 comments on commit 7272fec

Please sign in to comment.