Skip to content

Commit

Permalink
Just rely on local saving
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Nov 10, 2024
1 parent e452cd7 commit 15b3a78
Showing 1 changed file with 69 additions and 35 deletions.
104 changes: 69 additions & 35 deletions pokemonred_puffer/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import multiprocessing as mp
import os
import pprint
import re
from typing import Annotated

import carbs.utils
Expand Down Expand Up @@ -128,51 +129,84 @@ def launch_sweep(
f"< {param.space.max}"
)

config = CARBSParams(
better_direction_sign=1,
is_wandb_logging_enabled=True,
wandb_params=WandbLoggingParams(
project_name="Pokemon",
run_id=sweep_id,
root_dir="carbs",
),
resample_frequency=5,
num_random_samples=len(params),
)
if sweep_id:
config = CARBSParams(
better_direction_sign=1,
is_wandb_logging_enabled=True,
wandb_params=WandbLoggingParams(
project_name="Pokemon",
run_id=sweep_id,
run_name=sweep_name,
),
resample_frequency=5,
num_random_samples=len(params),
checkpoint_dir=f"carbs/checkpoints/{sweep_id}",
)

carbs = CARBS(config=config, params=params)
if not dry_run:
if sweep_id:
# runname = entity/project/run_id
carbs.warm_start_from_wandb(
run_name=f"{base_config.wandb.entity}/{base_config.wandb.project}/{sweep_id}",
carbs = CARBS(config=config, params=params)
# for wandb
# runname = entity/project/run_id
# find most recent file in checkpoint dir
experiment_dir = f"carbs/checkpoints/{sweep_id}/{sweep_name}"
saves = [
save_filename
for save_filename in os.listdir(experiment_dir)
if re.match(r"carbs_\d+obs.pt", save_filename)
]

if saves:
# sort by the middle int and take the highest value
# dont need split, could also use a regex group
save_filename = sorted(
saves, key=lambda x: int(x.replace("carbs_", "").replace("obs.pt", ""))
)[0]
carbs.warm_start(
filename=os.path.join(experiment_dir, save_filename),
is_prior_observation_valid=True,
)
else:
sweep_id = wandb.sweep(
sweep={
"name": sweep_name,
"controller": {"type": "local"},
"parameters": {
p.name: {"min": p.space.min, "max": p.space.max} for p in params
},
"metric": {
"name": "environment/stats/required_count",
"goal": "maximize",
"goal_value": 100,
},
"command": ["${args_json}"],

if not sweep_id and not dry_run:
sweep_id = wandb.sweep(
sweep={
"name": sweep_name,
"controller": {"type": "local"},
"parameters": {p.name: {"min": p.space.min, "max": p.space.max} for p in params},
"metric": {
"name": "environment/stats/required_count",
"goal": "maximize",
},
entity=base_config.wandb.entity,
project=base_config.wandb.project,
)
"command": ["${args_json}"],
},
entity=base_config.wandb.entity,
project=base_config.wandb.project,
)
config = CARBSParams(
better_direction_sign=1,
is_wandb_logging_enabled=True,
wandb_params=WandbLoggingParams(
project_name="Pokemon",
run_id=sweep_id,
run_name=sweep_name,
),
resample_frequency=5,
num_random_samples=len(params),
checkpoint_dir=f"carbs/checkpoints/{sweep_id}",
)
carbs = CARBS(config=config, params=params)
os.makedirs(os.path.join(config.checkpoint_dir, carbs.experiment_name), exist_ok=True)

carbs._autosave()

pprint.pprint(params)
if dry_run:
carbs.suggest()
return

sweep = wandb.controller(sweep_id)
sweep = wandb.controller(
sweep_id_or_config=sweep_id,
entity=base_config.wandb.entity,
project=base_config.wandb.project,
)

console.print(f"Beginning sweep with id {sweep_id}", style="bold")
console.print(
Expand Down

0 comments on commit 15b3a78

Please sign in to comment.