Skip to content

Commit

Permalink
I shouldn't need the second flush
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Oct 21, 2024
1 parent 61ab4c4 commit 603fd8b
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 20 deletions.
16 changes: 0 additions & 16 deletions pokemonred_puffer/cleanrl_puffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,6 @@ def __post_init__(self):
self.archive_path.mkdir(exist_ok=False)
print(f"Will archive states to {self.archive_path}")

if self.sqlite_db:
self.conn = sqlite3.connect(self.sqlite_db)
self.cur = self.conn.cursor()

@pufferlib.utils.profile
def evaluate(self):
# states are managed separately so dont worry about deleting them
Expand Down Expand Up @@ -384,18 +380,6 @@ def evaluate(self):
if all(not reset for reset, env_id in resets if env_id in key_set):
break
time.sleep(0.5)
# Flush any waiting workers
while self.vecenv.waiting_workers:
worker = self.vecenv.waiting_workers.pop(0)
sem = self.vecenv.buf.semaphores[worker]
if sem >= pufferlib.vector.MAIN:
self.vecenv.ready_workers.append(worker)
else:
self.vecenv.waiting_workers.append(worker)
self.vecenv.waiting_workers, self.vecenv.ready_workers = (
self.vecenv.ready_workers,
self.vecenv.waiting_workers,
)
if self.config.async_wrapper:
for key, state in zip(self.event_tracker.keys(), new_states):
self.env_recv_queues[key].put(state)
Expand Down
2 changes: 1 addition & 1 deletion pokemonred_puffer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def train(

sqlite_context = nullcontext
if config.train.get("sqlite_wrapper", False):
sqlite_context = NamedTemporaryFile
sqlite_context = functools.partial(NamedTemporaryFile, suffix="sqlite")

with sqlite_context() as sqlite_db:
db_filename = None
Expand Down
7 changes: 4 additions & 3 deletions pokemonred_puffer/wrappers/sqlite.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from os import PathLike
import os
import sqlite3
from typing import Any

Expand All @@ -19,10 +20,10 @@ def __init__(
cur = conn.cursor()
cur.execute(
"""
INSERT INTO states(env_id, pyboy_state, reset)
VALUES(?, ?, ?)
INSERT INTO states(env_id, pyboy_state, reset, pid)
VALUES(?, ?, ?, ?)
""",
(self.env.unwrapped.env_id, b"", 0),
(self.env.unwrapped.env_id, b"", 0, os.getpid()),
)
print(f"Initialized sqlite row {self.env.unwrapped.env_id}")

Expand Down

0 comments on commit 603fd8b

Please sign in to comment.