Skip to content

Commit

Permalink
One more try. Double check
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Oct 20, 2024
1 parent 97b36a9 commit 0aac8b5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
14 changes: 13 additions & 1 deletion pokemonred_puffer/cleanrl_puffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def evaluate(self):
""",
tuple(
[
{"state": state, "reset": True, "env_id": env_id}
{"state": state, "reset": 1, "env_id": env_id}
for state, env_id in zip(
new_states, self.event_tracker.keys()
)
Expand All @@ -384,6 +384,18 @@ def evaluate(self):
if all(not reset for reset, env_id in resets if env_id in key_set):
break
time.sleep(0.5)
# Flush any waiting workers
while self.vecenv.waiting_workers:
worker = self.vecenv.waiting_workers.pop(0)
sem = self.vecenv.buf.semaphores[worker]
if sem >= pufferlib.vector.MAIN:
self.vecenv.ready_workers.append(worker)
else:
self.vecenv.waiting_workers.append(worker)
self.vecenv.waiting_workers, self.vecenv.ready_workers = (
self.vecenv.ready_workers,
self.vecenv.waiting_workers,
)
if self.config.async_wrapper:
for key, state in zip(self.event_tracker.keys(), new_states):
self.env_recv_queues[key].put(state)
Expand Down
4 changes: 2 additions & 2 deletions pokemonred_puffer/wrappers/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(
INSERT INTO states(env_id, pyboy_state, reset)
VALUES(?, ?, ?)
""",
(self.env.unwrapped.env_id, b"", False),
(self.env.unwrapped.env_id, b"", 0),
)
print(f"Initialized sqlite row {self.env.unwrapped.env_id}")

Expand All @@ -46,7 +46,7 @@ def reset(self, seed: int | None = None, options: dict[str, Any] | None = None):
cur.execute(
"""
UPDATE states
SET reset = False
SET reset = 0
WHERE env_id = ?
""",
(self.env.unwrapped.env_id,),
Expand Down

0 comments on commit 0aac8b5

Please sign in to comment.