diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py index 4306794..552f2e1 100644 --- a/pokemonred_puffer/cleanrl_puffer.py +++ b/pokemonred_puffer/cleanrl_puffer.py @@ -362,7 +362,7 @@ def evaluate(self): """, tuple( [ - {"state": state, "reset": True, "env_id": env_id} + {"state": state, "reset": 1, "env_id": env_id} for state, env_id in zip( new_states, self.event_tracker.keys() ) @@ -384,6 +384,18 @@ def evaluate(self): if all(not reset for reset, env_id in resets if env_id in key_set): break time.sleep(0.5) + # Flush any waiting workers + while self.vecenv.waiting_workers: + worker = self.vecenv.waiting_workers.pop(0) + sem = self.vecenv.buf.semaphores[worker] + if sem >= pufferlib.vector.MAIN: + self.vecenv.ready_workers.append(worker) + else: + self.vecenv.waiting_workers.append(worker) + self.vecenv.waiting_workers, self.vecenv.ready_workers = ( + self.vecenv.ready_workers, + self.vecenv.waiting_workers, + ) if self.config.async_wrapper: for key, state in zip(self.event_tracker.keys(), new_states): self.env_recv_queues[key].put(state) diff --git a/pokemonred_puffer/wrappers/sqlite.py b/pokemonred_puffer/wrappers/sqlite.py index dbe3195..6afba2a 100644 --- a/pokemonred_puffer/wrappers/sqlite.py +++ b/pokemonred_puffer/wrappers/sqlite.py @@ -22,7 +22,7 @@ def __init__( INSERT INTO states(env_id, pyboy_state, reset) VALUES(?, ?, ?) """, - (self.env.unwrapped.env_id, b"", False), + (self.env.unwrapped.env_id, b"", 0), ) print(f"Initialized sqlite row {self.env.unwrapped.env_id}") @@ -46,7 +46,7 @@ def reset(self, seed: int | None = None, options: dict[str, Any] | None = None): cur.execute( """ UPDATE states - SET reset = False + SET reset = 0 WHERE env_id = ? """, (self.env.unwrapped.env_id,),