diff --git a/pokemonred_puffer/sweep.py b/pokemonred_puffer/sweep.py index 4602841..4b8371e 100644 --- a/pokemonred_puffer/sweep.py +++ b/pokemonred_puffer/sweep.py @@ -90,7 +90,6 @@ def launch_sweep( wandb_params=WandbLoggingParams(project_name="Pokemon", run_name="Pokemon"), ) params = sweep_config_to_params(base_config, sweep_config) - breakpoint() carbs = CARBS(config=config, params=params) sweep_id = wandb.sweep( @@ -132,9 +131,9 @@ def launch_sweep( elif ( run["state"] in [ - RunState.failed.value, + # RunState.failed.value, RunState.finished.value, - RunState.crashed.value, + # RunState.crashed.value, ] and run["name"] not in finished ): @@ -142,6 +141,7 @@ def launch_sweep( summary_metrics = json.loads(run["summaryMetrics"]) obs_in = ObservationInParam( input=json.loads(run["config"])["x"]["value"], + # TODO: try out other stats like required count output=summary_metrics["environment/stats/event"], cost=summary_metrics["performance/uptime"], ) diff --git a/pokemonred_puffer/train.py b/pokemonred_puffer/train.py index ff515c1..b604484 100644 --- a/pokemonred_puffer/train.py +++ b/pokemonred_puffer/train.py @@ -307,8 +307,8 @@ def train( Vectorization, typer.Option(help="Vectorization method") ] = "multiprocessing", ): - config.vectorization = vectorization config = load_from_config(config, debug) + config.vectorization = vectorization config, env_creator = setup( config=config, debug=debug,