Skip to content

Commit

Permalink
handle the scenario until infos gets updated
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Apr 14, 2024
1 parent 7185087 commit 62c50c5
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 25 deletions.
14 changes: 9 additions & 5 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,27 @@ debug:
headless: True
stream_wrapper: False
init_state: cut
max_steps: 4
train:
device: cpu
compile: False
compile_mode: default
num_envs: 10
envs_per_worker: 1
envs_per_batch: 1
batch_size: 128
batch_rows: 32
bptt_horizon: 4
batch_size: 8
batch_rows: 4
bptt_horizon: 2
total_timesteps: 100_000_000
save_checkpoint: True
checkpoint_interval: 4
save_overlay: True
overlay_interval: 4
verbose: True
verbose: False
env_pool: False
log_frequency: 5000
load_optimizer_state: False
swarm_frequency: 100
swarm_frequency: 5
swarm_keep_pct: .8

env:
Expand Down Expand Up @@ -88,6 +89,9 @@ train:
pool_kernel: [0]
load_optimizer_state: False

swarm_frequency: 500
swarm_keep_pct: .8

wrappers:
baseline:
- stream_wrapper.StreamWrapper:
Expand Down
24 changes: 11 additions & 13 deletions pokemonred_puffer/cleanrl_puffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,21 +374,19 @@ def evaluate(self):
key=lambda x: x[1],
)
]
# TODO: Not every one of these learners will have a recently saved state.
# Find a good way to tell them to make a saved state even if it is with a reset or get
reset_states = [
random.choice(largest) if i not in largest else i
for i in range(self.config.num_envs)
]
# Do this with + because f-strings are tricky. Can do nested f-string in a later python.
print("Migrating states:")
for i, n in enumerate(reset_states):
print(
f'\t {i} -> {n}, event scores: {self.infos["learner"]["reward/event"][i]} -> {self.infos["learner"]["reward/event"][n]}'
)
for i in range(self.config.num_envs):
self.env_recv_queues[i].put(self.infos["learner"]["state"][reset_states[i]])
waiting_for = []
# Need a way to keep the driver env from offsetting the env id counter.
# Until then, env ids are 1-indexed.
for i in range(self.config.num_envs):
if i not in largest:
new_state = random.choice(largest)
print(
f'\t {i+1} -> {new_state+1}, event scores: {self.infos["learner"]["reward/event"][i]} -> {self.infos["learner"]["reward/event"][new_state]}'
)
self.env_recv_queues[i + 1].put(self.infos["learner"]["state"][new_state])
waiting_for.append(i + 1)
for i in waiting_for:
self.env_send_queues[i].get()

self.policy_pool.update_policies()
Expand Down
5 changes: 3 additions & 2 deletions pokemonred_puffer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,9 @@ def train(
env_creator: Callable,
agent_creator: Callable[[gym.Env, pufferlib.namespace], pufferlib.models.Policy],
):
env_send_queues = [Queue() for _ in range(args.train.num_envs)]
env_recv_queues = [Queue() for _ in range(args.train.num_envs)]
# TODO: Remove the +1 once the driver env doesn't permanently increase the env id
env_send_queues = [Queue() for _ in range(args.train.num_envs + 1)]
env_recv_queues = [Queue() for _ in range(args.train.num_envs + 1)]
with CleanPuffeRL(
config=args.train,
agent_creator=agent_creator,
Expand Down
15 changes: 10 additions & 5 deletions pokemonred_puffer/wrappers/async_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ class AsyncWrapper(gym.Wrapper):
def __init__(self, env: RedGymEnv, send_queues: list[Queue], recv_queues: list[Queue]):
super().__init__(env)
# We need to -1 because the env id is one offset due to puffer's driver env
self.send_queue = send_queues[self.env.unwrapped.env_id - 1]
self.recv_queue = recv_queues[self.env.unwrapped.env_id - 1]
self.send_queue = send_queues[self.env.unwrapped.env_id]
self.recv_queue = recv_queues[self.env.unwrapped.env_id]
print(f"Initialized queues for {self.env.unwrapped.env_id}")
# Now we spawn a thread that listens for updates
# and sends a message back once the new state has been loaded.
# This is a slow process and should rarely happen.
Expand All @@ -19,6 +20,10 @@ def __init__(self, env: RedGymEnv, send_queues: list[Queue], recv_queues: list[Q
# TODO: Figure out if there's a safe way to exit the thread

def update(self):
while new_state := self.recv_queue.get():
self.env.update_state(new_state)
self.send_queue.put(self.env.unwrapped.env_id - 1)
while True:
new_state = self.recv_queue.get()
if new_state == b"":
print(f"invalid state for {self.env.unwrapped.env_id} skipping...")
else:
self.env.unwrapped.update_state(new_state)
self.send_queue.put(self.env.unwrapped.env_id)

0 comments on commit 62c50c5

Please sign in to comment.