From e4a528af996eac80e763fe0e10cf2f7e96a06097 Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Mon, 8 Apr 2024 22:23:37 +0900 Subject: [PATCH] swarm every n updates --- config.yaml | 2 + pokemonred_puffer/cleanrl_puffer.py | 88 ++++++++++++++++++++++++++++- 2 files changed, 87 insertions(+), 3 deletions(-) diff --git a/config.yaml b/config.yaml index e71cb13..3809938 100644 --- a/config.yaml +++ b/config.yaml @@ -27,6 +27,8 @@ debug: env_pool: False log_frequency: 5000 load_optimizer_state: False + swarm_frequency: 1 + swarm_pct: 10 env: headless: True diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py index 7c64dab..290dac6 100644 --- a/pokemonred_puffer/cleanrl_puffer.py +++ b/pokemonred_puffer/cleanrl_puffer.py @@ -1,3 +1,6 @@ +import heapq +import io +import math import os import pathlib import random @@ -163,6 +166,52 @@ def print_dashboard(stats, init_performance, performance): time.sleep(1 / 20) +# this is a hack to make pufferlib's async reset support kwargs +def async_reset_mp(self, seed=None, **kwargs): + pufferlib.vectorization.reset_precheck(self) + + if seed is None: + for idx, pipe in enumerate(self.send_pipes): + pipe.send( + ( + "reset", + [], + { + k: v[self.envs_per_worker * idx : self.envs_per_worker * (idx + 1)] + for k, v in kwargs.items() + }, + ) + ) + else: + for idx, pipe in enumerate(self.send_pipes): + pipe.send( + ( + "reset", + [], + ( + {"seed": seed + idx} + | { + k: v[self.envs_per_worker * idx : self.envs_per_worker * (idx + 1)] + for k, v in kwargs.items() + } + ), + ) + ) + + +def async_reset_serial(self, seed=None, **kwargs): + pufferlib.vectorization.reset_precheck(self) + if seed is None: + self.data = [ + e.reset({k: v[idx] for k, v in kwargs.items()}) for idx, e in enumerate(self.multi_envs) + ] + else: + self.data = [ + e.reset(seed=seed + idx, **{k: v[idx] for k, v in kwargs.items()}) + for idx, e in enumerate(self.multi_envs) + ] + + # TODO: Make this an unfrozen dataclass with a post_init? class CleanPuffeRL: def __init__( @@ -216,6 +265,10 @@ def __init__( env_pool=config.env_pool, mask_agents=True, ) + if isinstance(self.pool, pufferlib.vectorization.Serial): + self.pool.async_reset = async_reset_serial + elif isinstance(self.pool, pufferlib.vectorization.Multiprocessing): + self.pool.async_reset = async_reset_mp obs_shape = self.pool.single_observation_space.shape atn_shape = self.pool.single_action_space.shape @@ -349,6 +402,37 @@ def evaluate(self): ) self.log = False + # now for a tricky bit: + # if we have swarm_frequency, we will take the top swarm_pct envs and evenly distribute + # their states to the bottom 90%. + # we do this here so the environment can remain "pure" + if ( + hasattr(self.config, "swarm_frequency") + and hasattr(self.config, "swarm_pct") + and self.update % self.config.swarm_frequency == 0 + ): + # collect the top swarm_pct % of envs + largest = set( + x[0] + for x in heapq.nlargest( + math.ceil(self.config.num_envs * self.config.swarm_pct), + enumerate(self.infos["learner"]["reward/event"]), + key=lambda x: x[1], + ) + ) + # TODO: Not every one of these learners will have a recently saved state. + # Find a good way to tell them to make a saved state even if it is with a reset or get + reset_states = [ + random.choice(largest) if i not in largest else i + for i in range(self.config.num_envs) + ] + # unsure if bytes io can deep copy so I'm gonna make a bunch of copies here + for i in range(self.config.num_envs): + reset_states[i] = io.BytesIO(self.infos["learner"]["state"][i].read()) + self.infos["learner"]["state"][i].seek(0) + # now async reset the envs + self.pool.async_reset(self.config.seed, reset_states=reset_states) + self.policy_pool.update_policies() env_profiler = pufferlib.utils.Profiler() inference_profiler = pufferlib.utils.Profiler() @@ -442,9 +526,6 @@ def evaluate(self): with env_profiler: self.pool.send(actions) - if "state" in self.infos: - breakpoint() - eval_profiler.stop() # Now that we initialized the model, we can get the number of parameters @@ -454,6 +535,7 @@ def evaluate(self): self.total_agent_steps += padded_steps_collected new_step = np.mean(self.infos["learner"]["stats/step"]) + if new_step > self.global_step: self.global_step = new_step self.log = True