diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py index 3a4eefb..b4ee817 100644 --- a/pokemonred_puffer/cleanrl_puffer.py +++ b/pokemonred_puffer/cleanrl_puffer.py @@ -370,27 +370,27 @@ def evaluate(self): ), ) self.vecenv.async_reset() + with sqlite3.connect(self.sqlite_db) as conn: + while True: + resets = cur.executemany( + """ + SELECT reset + FROM states + WHERE env_id=:env_id + """, + tuple( + [{"env_id": env_id} for env_id in self.event_tracker.keys()] + ), + ).fetchall() + if all(not reset for reset in resets): + break + time.sleep(0.5) if self.config.async_wrapper: for key, state in zip(self.event_tracker.keys(), new_states): self.env_recv_queues[key].put(state) for key in self.event_tracker.keys(): # print(f"\tWaiting for message from env-id {key}") self.env_send_queues[key].get() - # Alternative: reopoen sqlite3 connection with - # SELECT count(*) FROM states WHERE reset=False - # == SELECT count(*) - # Flush any waiting workers - while self.vecenv.waiting_workers: - worker = self.vecenv.waiting_workers.pop(0) - sem = self.vecenv.buf.semaphores[worker] - if sem >= pufferlib.vector.MAIN: - self.vecenv.ready_workers.append(worker) - else: - self.vecenv.waiting_workers.append(worker) - self.vecenv.ready_workers, self.vecenv.waiting_workers = ( - self.vecenv.waiting_workers, - self.vecenv.ready_workers, - ) print( f"State migration to {self.archive_path}/{str(hash(new_state_key))} complete"