Skip to content

Commit

Permalink
neater sweeper?
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Oct 12, 2024
1 parent fca6ff9 commit 89ba029
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 85 deletions.
2 changes: 1 addition & 1 deletion config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ debug:
minibatch_size: 4
batch_rows: 4
bptt_horizon: 2
total_timesteps: 100_000_000
total_timesteps: 10
save_checkpoint: True
checkpoint_interval: 4
save_overlay: True
Expand Down
119 changes: 54 additions & 65 deletions pokemonred_puffer/sweep.py
Original file line number Diff line number Diff line change
@@ -1,118 +1,107 @@
import json
import math
from pathlib import Path
from typing import Annotated, Any
from typing import Annotated

import carbs.utils
import sweeps
import typer
import yaml
from carbs import (
CARBS,
CARBSParams,
ObservationInParam,
Param,
ParamDictType,
ParamType,
WandbLoggingParams,
)
from omegaconf import DictConfig, OmegaConf
from wandb_carbs import create_sweep

import wandb
from pokemonred_puffer import train

app = typer.Typer(pretty_exceptions_enable=False)


def sweep_config_to_params(sweep_cfg: dict[str, Any], prefix: str = "") -> list[Param]:
def sweep_config_to_params(sweep_config: DictConfig, prefix: str = "") -> list[Param]:
res = []
for k, v in sweep_cfg.items():
for k, v in sweep_config.items():
# A little hacky. Maybe I should not make this all config based
if k.startswith("carbs.utils"):
param_class = getattr(carbs.utils, k.split(".")[-1])
res += [
Param(
prefix.removesuffix("/").removeprefix("/"),
prefix.removesuffix("-").removeprefix("-"),
param_class(**v),
(v["max"] + v["min"]) // 2
if v.get("is_integer", False)
else math.sqrt(v["max"] ** 2 + v["min"] ** 2),
)
]
elif isinstance(v, dict):
res += sweep_config_to_params(v, prefix=prefix + "/" + k)
elif isinstance(v, DictConfig):
res += sweep_config_to_params(v, prefix=prefix + "-" + k)
else:
print(type(v))
return res


def update_base_config_by_key(
base_cfg: dict[str, Any], key: str, value: ParamType
) -> dict[str, Any]:
new_cfg = base_cfg.copy()
keys = key.split("/", 1)
def update_base_config_by_key(base_config: DictConfig, key: str, value: ParamType) -> DictConfig:
new_config = base_config.copy()
keys = key.split("-", 1)
if len(keys) == 1:
new_cfg[keys[0]] = value
new_config[keys[0]] = value
else:
new_cfg[keys[0]] = update_base_config_by_key(new_cfg[keys[0]], keys[1], value)
return new_cfg
new_config[keys[0]] = update_base_config_by_key(new_config[keys[0]], keys[1], value)
return new_config


def update_base_config(base_cfg: dict[str, Any], suggestion: ParamDictType) -> dict[str, Any]:
new_cfg = base_cfg.copy()
def update_base_config(base_config: DictConfig, suggestion: ParamDictType) -> DictConfig:
new_config = base_config.copy()
for k, v in suggestion.items():
new_cfg = update_base_config_by_key(new_cfg, k, v)
return new_cfg
new_config = update_base_config_by_key(new_config, k, v)
return new_config


@app.command()
def launch_controller(
base_config: Annotated[Path, typer.Option(help="Base configuration")] = Path("config.yaml"),
def launch_sweep(
base_config: Annotated[
DictConfig, typer.Option(help="Base configuration", parser=OmegaConf.load)
] = "config.yaml",
sweep_config: Annotated[
Path, typer.Option(help="CARBS sweep config. settings must match base config.")
] = Path("sweep-config.yaml"),
DictConfig,
typer.Option(
help="CARBS sweep config. settings must match base config.", parser=OmegaConf.load
),
] = "sweep-config.yaml",
sweep_name: Annotated[str, typer.Option(help="Sweep name")] = "PokeSweep",
):
with open(base_config) as f:
base_cfg = yaml.safe_load(f)
with open(sweep_config) as f:
sweep_cfg = yaml.safe_load(f)
config = CARBSParams(
better_direction_sign=-1,
is_wandb_logging_enabled=True,
wandb_params=WandbLoggingParams(project_name="Pokemon", run_name="Pokemon"),
)
params = sweep_config_to_params(sweep_cfg)
params = sweep_config_to_params(sweep_config)
import pprint

pprint.pprint(params)
carbs = CARBS(config=config, params=params)
sweep_id = wandb.sweep(
sweep={"controller": {"type": "local"}, "parameters": {}},
sweep_id = create_sweep(
sweep_name=sweep_name,
wandb_entity=base_config.wandb.entity,
wandb_project=base_config.wandb.project,
carb_params=params,
)
sweep = wandb.controller(sweep_id)

print(f"Beginning sweep with id {sweep_id}")
print(
f"On all nodes please run wandb with wandb.agent(sweep_id={sweep_id}, "
"function=<your-function>)"
)
while not sweep.done():
sweep_obj = sweep._sweep_object_read_from_backend()
if sweep_obj["runs"]:
print(sweep_obj["runs"])
breakpoint()
obs_in = ObservationInParam(...) # parsed from sweep_obj. Need to figure this out
carbs.observe(obs_in)
suggestion = carbs.suggest()
new_cfg = update_base_config(base_cfg, suggestion.suggestion)
run = sweeps.SweepRun(config=new_cfg)
sweep.schedule(run)
sweep.print_status()
print(f"On all nodes please run python -m pokemonred_puffer.sweep launch-agent {sweep_id}")


@app.command()
def launch_agent(sweep_id: str):
def launch_agent(
sweep_id: str,
base_config: Annotated[
DictConfig, typer.Option(help="Base configuration", parser=OmegaConf.load)
] = "config.yaml",
):
def _fn():
new_config = update_base_config(base_config)
print(new_config)
train.train(config=new_config)

wandb.agent(
sweep_id,
lambda params: train.train(json.dumps(yaml=params, track=True)),
entity="Pokemon",
project="Pokemon",
sweep_id=sweep_id,
entity=base_config.wandb.entity,
project=base_config.wandb.project,
function=_fn,
count=999999,
)


Expand Down
30 changes: 14 additions & 16 deletions pokemonred_puffer/train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import functools
import importlib
import json
import os
import uuid
from contextlib import contextmanager
Expand Down Expand Up @@ -49,9 +48,7 @@ def make_policy(env: RedGymEnv, policy_name: str, config: DictConfig) -> nn.Modu
return policy.to(config.train.device)


def load_from_config(yaml: Path, debug: bool) -> DictConfig:
config: DictConfig = OmegaConf.load(yaml)

def load_from_config(config: DictConfig, debug: bool) -> DictConfig:
default_keys = ["env", "train", "policies", "rewards", "wrappers", "wandb"]
defaults = OmegaConf.create({key: config.get(key, {}) for key in default_keys})

Expand Down Expand Up @@ -147,18 +144,13 @@ def init_wandb(


def setup(
yaml: Path | str,
config: DictConfig,
debug: bool,
wrappers_name: str,
reward_name: str,
rom_path: Path,
track: bool,
) -> tuple[DictConfig, Callable[[DictConfig, DictConfig], pufferlib.emulation.GymnasiumPufferEnv]]:
possible_dictconfig = json.load(yaml)
if isinstance(possible_dictconfig, dict):
config = OmegaConf.create(possible_dictconfig)
else:
config = load_from_config(yaml, debug)
config.train.exp_id = f"pokemon-red-{str(uuid.uuid4())[:8]}"
config.env.gb_path = rom_path
config.track = track
Expand All @@ -172,7 +164,9 @@ def setup(

@app.command()
def evaluate(
yaml: Annotated[Path | str, typer.Option(help="Configuration file to use")] = "config.yaml",
config: Annotated[
DictConfig, typer.Option(help="Base configuration", parser=OmegaConf.load)
] = "config.yaml",
checkpoint_path: Path | None = None,
policy_name: Annotated[
str,
Expand Down Expand Up @@ -201,7 +195,7 @@ def evaluate(
rom_path: Path = "red.gb",
):
config, env_creator = setup(
yaml=yaml,
config=config,
debug=False,
wrappers_name=wrappers_name,
reward_name=reward_name,
Expand All @@ -227,7 +221,9 @@ def evaluate(

@app.command()
def autotune(
yaml: Annotated[Path, typer.Option(help="Configuration file to use")] = "config.yaml",
config: Annotated[
DictConfig, typer.Option(help="Base configuration", parser=OmegaConf.load)
] = "config.yaml",
policy_name: Annotated[
str,
typer.Option(
Expand Down Expand Up @@ -255,7 +251,7 @@ def autotune(
rom_path: Path = "red.gb",
):
config, env_creator = setup(
yaml=yaml,
config=config,
debug=False,
wrappers_name=wrappers_name,
reward_name=reward_name,
Expand All @@ -275,7 +271,9 @@ def autotune(

@app.command()
def train(
yaml: Annotated[Path | str, typer.Option(help="Configuration file to use")] = "config.yaml",
config: Annotated[
DictConfig, typer.Option(help="Base configuration", parser=OmegaConf.load)
] = "config.yaml",
policy_name: Annotated[
str,
typer.Option(
Expand Down Expand Up @@ -309,7 +307,7 @@ def train(
] = "multiprocessing",
):
config, env_creator = setup(
yaml=yaml,
config=config,
debug=debug,
wrappers_name=wrappers_name,
reward_name=reward_name,
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ classifiers = [
]
dependencies = [
"einops",
"mediapy",
"numba",
"numpy",
"omegaconf",
Expand All @@ -24,7 +23,7 @@ dependencies = [
"torch>=2.4",
"torchvision",
"typer",
"wandb[sweeps]",
"wandb",
"websockets"
]

Expand All @@ -42,7 +41,8 @@ dev = [
]
sweep = [
"carbs @ git+https://github.com/imbue-ai/carbs",
"sweeps"
"sweeps",
"wandb-carbs"
]

[tool.distutils.bdist_wheel]
Expand Down

0 comments on commit 89ba029

Please sign in to comment.