diff --git a/config.yaml b/config.yaml index 6940567..5d2874f 100644 --- a/config.yaml +++ b/config.yaml @@ -111,6 +111,7 @@ train: pool_kernel: [0] load_optimizer_state: False use_rnn: True + async_wrapper: False # swarm_frequency: 500 # swarm_keep_pct: .8 diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py index c126c9e..44d05d4 100644 --- a/pokemonred_puffer/cleanrl_puffer.py +++ b/pokemonred_puffer/cleanrl_puffer.py @@ -213,7 +213,8 @@ def evaluate(self): # their states to the bottom 90%. # we do this here so the environment can remain "pure" if ( - hasattr(self.config, "swarm_frequency") + self.config.async_wrapper + and hasattr(self.config, "swarm_frequency") and hasattr(self.config, "swarm_keep_pct") and self.epoch % self.config.swarm_frequency == 0 and "reward/event" in self.infos diff --git a/pokemonred_puffer/train.py b/pokemonred_puffer/train.py index a37f6e8..0fe40f6 100644 --- a/pokemonred_puffer/train.py +++ b/pokemonred_puffer/train.py @@ -284,7 +284,7 @@ def train( args = update_args(args) args.train.exp_id = f"pokemon-red-{str(uuid.uuid4())[:8]}" - async_wrapper = args.mode == "train" + async_wrapper = args.train.async_wrapper env_creator = setup_agent(args.wrappers[args.wrappers_name], args.reward_name, async_wrapper) wandb_client = None