diff --git a/config.yaml b/config.yaml index d6bfdce..9b0b850 100644 --- a/config.yaml +++ b/config.yaml @@ -31,7 +31,7 @@ debug: minibatch_size: 4 batch_rows: 4 bptt_horizon: 2 - total_timesteps: 100_000_000 + total_timesteps: 10 save_checkpoint: True checkpoint_interval: 4 save_overlay: True diff --git a/pokemonred_puffer/sweep.py b/pokemonred_puffer/sweep.py index 45557a3..ec803c0 100644 --- a/pokemonred_puffer/sweep.py +++ b/pokemonred_puffer/sweep.py @@ -1,21 +1,15 @@ -import json import math -from pathlib import Path -from typing import Annotated, Any +from typing import Annotated import carbs.utils -import sweeps import typer -import yaml from carbs import ( - CARBS, - CARBSParams, - ObservationInParam, Param, ParamDictType, ParamType, - WandbLoggingParams, ) +from omegaconf import DictConfig, OmegaConf +from wandb_carbs import create_sweep import wandb from pokemonred_puffer import train @@ -23,96 +17,91 @@ app = typer.Typer(pretty_exceptions_enable=False) -def sweep_config_to_params(sweep_cfg: dict[str, Any], prefix: str = "") -> list[Param]: +def sweep_config_to_params(sweep_config: DictConfig, prefix: str = "") -> list[Param]: res = [] - for k, v in sweep_cfg.items(): + for k, v in sweep_config.items(): # A little hacky. Maybe I should not make this all config based if k.startswith("carbs.utils"): param_class = getattr(carbs.utils, k.split(".")[-1]) res += [ Param( - prefix.removesuffix("/").removeprefix("/"), + prefix.removesuffix("-").removeprefix("-"), param_class(**v), (v["max"] + v["min"]) // 2 if v.get("is_integer", False) else math.sqrt(v["max"] ** 2 + v["min"] ** 2), ) ] - elif isinstance(v, dict): - res += sweep_config_to_params(v, prefix=prefix + "/" + k) + elif isinstance(v, DictConfig): + res += sweep_config_to_params(v, prefix=prefix + "-" + k) + else: + print(type(v)) return res -def update_base_config_by_key( - base_cfg: dict[str, Any], key: str, value: ParamType -) -> dict[str, Any]: - new_cfg = base_cfg.copy() - keys = key.split("/", 1) +def update_base_config_by_key(base_config: DictConfig, key: str, value: ParamType) -> DictConfig: + new_config = base_config.copy() + keys = key.split("-", 1) if len(keys) == 1: - new_cfg[keys[0]] = value + new_config[keys[0]] = value else: - new_cfg[keys[0]] = update_base_config_by_key(new_cfg[keys[0]], keys[1], value) - return new_cfg + new_config[keys[0]] = update_base_config_by_key(new_config[keys[0]], keys[1], value) + return new_config -def update_base_config(base_cfg: dict[str, Any], suggestion: ParamDictType) -> dict[str, Any]: - new_cfg = base_cfg.copy() +def update_base_config(base_config: DictConfig, suggestion: ParamDictType) -> DictConfig: + new_config = base_config.copy() for k, v in suggestion.items(): - new_cfg = update_base_config_by_key(new_cfg, k, v) - return new_cfg + new_config = update_base_config_by_key(new_config, k, v) + return new_config @app.command() -def launch_controller( - base_config: Annotated[Path, typer.Option(help="Base configuration")] = Path("config.yaml"), +def launch_sweep( + base_config: Annotated[ + DictConfig, typer.Option(help="Base configuration", parser=OmegaConf.load) + ] = "config.yaml", sweep_config: Annotated[ - Path, typer.Option(help="CARBS sweep config. settings must match base config.") - ] = Path("sweep-config.yaml"), + DictConfig, + typer.Option( + help="CARBS sweep config. settings must match base config.", parser=OmegaConf.load + ), + ] = "sweep-config.yaml", + sweep_name: Annotated[str, typer.Option(help="Sweep name")] = "PokeSweep", ): - with open(base_config) as f: - base_cfg = yaml.safe_load(f) - with open(sweep_config) as f: - sweep_cfg = yaml.safe_load(f) - config = CARBSParams( - better_direction_sign=-1, - is_wandb_logging_enabled=True, - wandb_params=WandbLoggingParams(project_name="Pokemon", run_name="Pokemon"), - ) - params = sweep_config_to_params(sweep_cfg) + params = sweep_config_to_params(sweep_config) import pprint pprint.pprint(params) - carbs = CARBS(config=config, params=params) - sweep_id = wandb.sweep( - sweep={"controller": {"type": "local"}, "parameters": {}}, + sweep_id = create_sweep( + sweep_name=sweep_name, + wandb_entity=base_config.wandb.entity, + wandb_project=base_config.wandb.project, + carb_params=params, ) - sweep = wandb.controller(sweep_id) + print(f"Beginning sweep with id {sweep_id}") - print( - f"On all nodes please run wandb with wandb.agent(sweep_id={sweep_id}, " - "function=)" - ) - while not sweep.done(): - sweep_obj = sweep._sweep_object_read_from_backend() - if sweep_obj["runs"]: - print(sweep_obj["runs"]) - breakpoint() - obs_in = ObservationInParam(...) # parsed from sweep_obj. Need to figure this out - carbs.observe(obs_in) - suggestion = carbs.suggest() - new_cfg = update_base_config(base_cfg, suggestion.suggestion) - run = sweeps.SweepRun(config=new_cfg) - sweep.schedule(run) - sweep.print_status() + print(f"On all nodes please run python -m pokemonred_puffer.sweep launch-agent {sweep_id}") @app.command() -def launch_agent(sweep_id: str): +def launch_agent( + sweep_id: str, + base_config: Annotated[ + DictConfig, typer.Option(help="Base configuration", parser=OmegaConf.load) + ] = "config.yaml", +): + def _fn(): + new_config = update_base_config(base_config) + print(new_config) + train.train(config=new_config) + wandb.agent( - sweep_id, - lambda params: train.train(json.dumps(yaml=params, track=True)), - entity="Pokemon", - project="Pokemon", + sweep_id=sweep_id, + entity=base_config.wandb.entity, + project=base_config.wandb.project, + function=_fn, + count=999999, ) diff --git a/pokemonred_puffer/train.py b/pokemonred_puffer/train.py index 31b6511..8934d34 100644 --- a/pokemonred_puffer/train.py +++ b/pokemonred_puffer/train.py @@ -1,6 +1,5 @@ import functools import importlib -import json import os import uuid from contextlib import contextmanager @@ -49,9 +48,7 @@ def make_policy(env: RedGymEnv, policy_name: str, config: DictConfig) -> nn.Modu return policy.to(config.train.device) -def load_from_config(yaml: Path, debug: bool) -> DictConfig: - config: DictConfig = OmegaConf.load(yaml) - +def load_from_config(config: DictConfig, debug: bool) -> DictConfig: default_keys = ["env", "train", "policies", "rewards", "wrappers", "wandb"] defaults = OmegaConf.create({key: config.get(key, {}) for key in default_keys}) @@ -147,18 +144,13 @@ def init_wandb( def setup( - yaml: Path | str, + config: DictConfig, debug: bool, wrappers_name: str, reward_name: str, rom_path: Path, track: bool, ) -> tuple[DictConfig, Callable[[DictConfig, DictConfig], pufferlib.emulation.GymnasiumPufferEnv]]: - possible_dictconfig = json.load(yaml) - if isinstance(possible_dictconfig, dict): - config = OmegaConf.create(possible_dictconfig) - else: - config = load_from_config(yaml, debug) config.train.exp_id = f"pokemon-red-{str(uuid.uuid4())[:8]}" config.env.gb_path = rom_path config.track = track @@ -172,7 +164,9 @@ def setup( @app.command() def evaluate( - yaml: Annotated[Path | str, typer.Option(help="Configuration file to use")] = "config.yaml", + config: Annotated[ + DictConfig, typer.Option(help="Base configuration", parser=OmegaConf.load) + ] = "config.yaml", checkpoint_path: Path | None = None, policy_name: Annotated[ str, @@ -201,7 +195,7 @@ def evaluate( rom_path: Path = "red.gb", ): config, env_creator = setup( - yaml=yaml, + config=config, debug=False, wrappers_name=wrappers_name, reward_name=reward_name, @@ -227,7 +221,9 @@ def evaluate( @app.command() def autotune( - yaml: Annotated[Path, typer.Option(help="Configuration file to use")] = "config.yaml", + config: Annotated[ + DictConfig, typer.Option(help="Base configuration", parser=OmegaConf.load) + ] = "config.yaml", policy_name: Annotated[ str, typer.Option( @@ -255,7 +251,7 @@ def autotune( rom_path: Path = "red.gb", ): config, env_creator = setup( - yaml=yaml, + config=config, debug=False, wrappers_name=wrappers_name, reward_name=reward_name, @@ -275,7 +271,9 @@ def autotune( @app.command() def train( - yaml: Annotated[Path | str, typer.Option(help="Configuration file to use")] = "config.yaml", + config: Annotated[ + DictConfig, typer.Option(help="Base configuration", parser=OmegaConf.load) + ] = "config.yaml", policy_name: Annotated[ str, typer.Option( @@ -309,7 +307,7 @@ def train( ] = "multiprocessing", ): config, env_creator = setup( - yaml=yaml, + config=config, debug=debug, wrappers_name=wrappers_name, reward_name=reward_name, diff --git a/pyproject.toml b/pyproject.toml index d5d0362..61759de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,6 @@ classifiers = [ ] dependencies = [ "einops", - "mediapy", "numba", "numpy", "omegaconf", @@ -24,7 +23,7 @@ dependencies = [ "torch>=2.4", "torchvision", "typer", - "wandb[sweeps]", + "wandb", "websockets" ] @@ -42,7 +41,8 @@ dev = [ ] sweep = [ "carbs @ git+https://github.com/imbue-ai/carbs", - "sweeps" + "sweeps", + "wandb-carbs" ] [tool.distutils.bdist_wheel]