Skip to content

Commit

Permalink
swarm every n updates
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Apr 8, 2024
1 parent 748de65 commit e4a528a
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 3 deletions.
2 changes: 2 additions & 0 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ debug:
env_pool: False
log_frequency: 5000
load_optimizer_state: False
swarm_frequency: 1
swarm_pct: 10

env:
headless: True
Expand Down
88 changes: 85 additions & 3 deletions pokemonred_puffer/cleanrl_puffer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import heapq
import io
import math
import os
import pathlib
import random
Expand Down Expand Up @@ -163,6 +166,52 @@ def print_dashboard(stats, init_performance, performance):
time.sleep(1 / 20)


# this is a hack to make pufferlib's async reset support kwargs
def async_reset_mp(self, seed=None, **kwargs):
    """Async-reset for pufferlib's Multiprocessing vectorization, with kwargs support.

    Monkeypatch replacement for ``async_reset`` that forwards extra keyword
    arguments to each worker. Every kwarg value is assumed to be a sequence of
    length ``num_envs``; each worker receives only the slice belonging to its
    own envs. When ``seed`` is given, worker ``idx`` is seeded with ``seed + idx``.
    """
    pufferlib.vectorization.reset_precheck(self)

    for idx, pipe in enumerate(self.send_pipes):
        # Carve out this worker's slice of every batched kwarg.
        lo = self.envs_per_worker * idx
        hi = self.envs_per_worker * (idx + 1)
        worker_kwargs = {key: value[lo:hi] for key, value in kwargs.items()}
        if seed is not None:
            # Per-worker seed first, so an explicit "seed" kwarg (if any) wins,
            # matching the original dict-union order.
            worker_kwargs = {"seed": seed + idx} | worker_kwargs
        # Message shape expected by the worker loop: (method, args, kwargs).
        pipe.send(("reset", [], worker_kwargs))


def async_reset_serial(self, seed=None, **kwargs):
    """Async-reset for pufferlib's Serial vectorization, with kwargs support.

    Monkeypatch replacement for ``async_reset`` that forwards extra keyword
    arguments to each env. Every kwarg value is assumed to be a sequence of
    length ``num_envs``; env ``idx`` receives element ``idx`` of each value.
    When ``seed`` is given, env ``idx`` is seeded with ``seed + idx``.
    """
    pufferlib.vectorization.reset_precheck(self)
    if seed is None:
        # BUG FIX: the kwargs dict must be **-unpacked into keyword arguments,
        # not passed positionally (which would bind it to the env's first
        # parameter, typically ``seed``). This also matches the ``else``
        # branch below and the Multiprocessing variant, which deliver the
        # per-env slices as keyword arguments.
        self.data = [
            e.reset(**{k: v[idx] for k, v in kwargs.items()})
            for idx, e in enumerate(self.multi_envs)
        ]
    else:
        self.data = [
            e.reset(seed=seed + idx, **{k: v[idx] for k, v in kwargs.items()})
            for idx, e in enumerate(self.multi_envs)
        ]


# TODO: Make this an unfrozen dataclass with a post_init?
class CleanPuffeRL:
def __init__(
Expand Down Expand Up @@ -216,6 +265,10 @@ def __init__(
env_pool=config.env_pool,
mask_agents=True,
)
if isinstance(self.pool, pufferlib.vectorization.Serial):
self.pool.async_reset = async_reset_serial
elif isinstance(self.pool, pufferlib.vectorization.Multiprocessing):
self.pool.async_reset = async_reset_mp

obs_shape = self.pool.single_observation_space.shape
atn_shape = self.pool.single_action_space.shape
Expand Down Expand Up @@ -349,6 +402,37 @@ def evaluate(self):
)
self.log = False

# now for a tricky bit:
# if we have swarm_frequency, we will take the top swarm_pct envs and evenly distribute
# their states to the bottom 90%.
# we do this here so the environment can remain "pure"
if (
hasattr(self.config, "swarm_frequency")
and hasattr(self.config, "swarm_pct")
and self.update % self.config.swarm_frequency == 0
):
# collect the top swarm_pct % of envs
largest = set(
x[0]
for x in heapq.nlargest(
math.ceil(self.config.num_envs * self.config.swarm_pct),
enumerate(self.infos["learner"]["reward/event"]),
key=lambda x: x[1],
)
)
# TODO: Not every one of these learners will have a recently saved state.
# Find a good way to tell them to make a saved state even if it is with a reset or get
reset_states = [
random.choice(largest) if i not in largest else i
for i in range(self.config.num_envs)
]
# unsure if bytes io can deep copy so I'm gonna make a bunch of copies here
for i in range(self.config.num_envs):
reset_states[i] = io.BytesIO(self.infos["learner"]["state"][i].read())
self.infos["learner"]["state"][i].seek(0)
# now async reset the envs
self.pool.async_reset(self.config.seed, reset_states=reset_states)

self.policy_pool.update_policies()
env_profiler = pufferlib.utils.Profiler()
inference_profiler = pufferlib.utils.Profiler()
Expand Down Expand Up @@ -442,9 +526,6 @@ def evaluate(self):
with env_profiler:
self.pool.send(actions)

if "state" in self.infos:
breakpoint()

eval_profiler.stop()

# Now that we initialized the model, we can get the number of parameters
Expand All @@ -454,6 +535,7 @@ def evaluate(self):

self.total_agent_steps += padded_steps_collected
new_step = np.mean(self.infos["learner"]["stats/step"])

if new_step > self.global_step:
self.global_step = new_step
self.log = True
Expand Down

0 comments on commit e4a528a

Please sign in to comment.