diff --git a/pokemonred_puffer/sweep.py b/pokemonred_puffer/sweep.py index 465fbca..62814d8 100644 --- a/pokemonred_puffer/sweep.py +++ b/pokemonred_puffer/sweep.py @@ -10,6 +10,7 @@ import typer from carbs import ( CARBS, + LogitSpace, Param, ParamDictType, ParamType, @@ -98,17 +99,34 @@ def launch_sweep( params = sweep_config_to_params(base_config, sweep_config) for param in params: print(f"Checking param: {param}") - assert ( - param.space.min - < param.search_center - param.space.scale - < param.search_center + param.space.scale - < param.space.max - ), ( - f"{param.space.min} " - f"< {param.search_center} - {param.space.scale} " - f"< {param.search_center} + {param.space.scale} " - f"< {param.space.max}" - ) + if isinstance(param.space, LogitSpace): + assert ( + 0.0 + <= param.space.min + < param.search_center - param.space.scale + < param.search_center + param.space.scale + < param.space.max + <= 1.0 + ), ( + "0.0 " + f"<= {param.space.min} " + f"< {param.search_center} - {param.space.scale} " + f"< {param.search_center} + {param.space.scale} " + f"< {param.space.max} " + f"<= 1.0" + ) + else: + assert ( + param.space.min + < param.search_center - param.space.scale + < param.search_center + param.space.scale + < param.space.max + ), ( + f"{param.space.min} " + f"< {param.search_center} - {param.space.scale} " + f"< {param.search_center} + {param.space.scale} " + f"< {param.space.max}" + ) config = CARBSParams( better_direction_sign=1, diff --git a/sweep-config.yaml b/sweep-config.yaml index 372a552..1fc1574 100644 --- a/sweep-config.yaml +++ b/sweep-config.yaml @@ -4,7 +4,7 @@ env: min: 10240 max: 81920 is_integer: True - scale: 10000 + scale: 2500 train: total_timesteps: @@ -12,7 +12,7 @@ train: min: 500_000_000 max: 10_000_000_000 is_integer: True - scale: 100_000_000 + scale: 10_000_000 learning_rate: carbs.utils.LogSpace: min: 1.0e-5 @@ -20,12 +20,12 @@ train: scale: 1.0e-4 gamma: carbs.utils.LogitSpace: - min: .5 + min: .75 max: 1.0 scale: .0005 gae_lambda: carbs.utils.LogitSpace: - min: .5 + min: .75 max: 1.0 scale: .01 ent_coef: @@ -51,22 +51,22 @@ rewards: carbs.utils.LogSpace: min: 1.0e-3 max: 10.0 - scale: 2.0 + scale: 1.0 caught_pokemon: carbs.utils.LogSpace: min: 1.0e-3 max: 10.0 - scale: 2.0 + scale: 1.0 moves_obtained: carbs.utils.LogSpace: min: 1.0e-3 max: 10.0 - scale: 2.0 + scale: 1.0 hm_count: carbs.utils.LogSpace: min: 1.0e-3 max: 15.0 - scale: 2.0 + scale: 1.0 level: carbs.utils.LogSpace: min: 1.0e-5 @@ -91,12 +91,12 @@ rewards: carbs.utils.LogSpace: min: 1.0e-3 max: 10.0 - scale: 2.0 + scale: 1.0 required_item: carbs.utils.LogSpace: min: 1.0e-3 max: 10.0 - scale: 2.0 + scale: 1.0 useful_item: carbs.utils.LogSpace: min: 1.0e-3