From 0aac8b51c9389e4c6fe00cf9ed47ca8a46d8f2a0 Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Sun, 20 Oct 2024 16:13:35 -0400 Subject: [PATCH] One more try. Double check --- pokemonred_puffer/cleanrl_puffer.py | 14 +++++++++++++- pokemonred_puffer/wrappers/sqlite.py | 4 ++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py index 4306794..552f2e1 100644 --- a/pokemonred_puffer/cleanrl_puffer.py +++ b/pokemonred_puffer/cleanrl_puffer.py @@ -362,7 +362,7 @@ def evaluate(self): """, tuple( [ - {"state": state, "reset": True, "env_id": env_id} + {"state": state, "reset": 1, "env_id": env_id} for state, env_id in zip( new_states, self.event_tracker.keys() ) @@ -384,6 +384,18 @@ def evaluate(self): if all(not reset for reset, env_id in resets if env_id in key_set): break time.sleep(0.5) + # Flush any waiting workers + while self.vecenv.waiting_workers: + worker = self.vecenv.waiting_workers.pop(0) + sem = self.vecenv.buf.semaphores[worker] + if sem >= pufferlib.vector.MAIN: + self.vecenv.ready_workers.append(worker) + else: + self.vecenv.waiting_workers.append(worker) + self.vecenv.waiting_workers, self.vecenv.ready_workers = ( + self.vecenv.ready_workers, + self.vecenv.waiting_workers, + ) if self.config.async_wrapper: for key, state in zip(self.event_tracker.keys(), new_states): self.env_recv_queues[key].put(state) diff --git a/pokemonred_puffer/wrappers/sqlite.py b/pokemonred_puffer/wrappers/sqlite.py index dbe3195..6afba2a 100644 --- a/pokemonred_puffer/wrappers/sqlite.py +++ b/pokemonred_puffer/wrappers/sqlite.py @@ -22,7 +22,7 @@ def __init__( INSERT INTO states(env_id, pyboy_state, reset) VALUES(?, ?, ?) """, - (self.env.unwrapped.env_id, b"", False), + (self.env.unwrapped.env_id, b"", 0), ) print(f"Initialized sqlite row {self.env.unwrapped.env_id}") @@ -46,7 +46,7 @@ def reset(self, seed: int | None = None, options: dict[str, Any] | None = None): cur.execute( """ UPDATE states - SET reset = False + SET reset = 0 WHERE env_id = ? """, (self.env.unwrapped.env_id,),