Skip to content

Commit

Permalink
Add some locks. Hopefully things wont slow down too much
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Nov 2, 2024
1 parent f7e7f8b commit 6c5143a
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 59 deletions.
55 changes: 29 additions & 26 deletions pokemonred_puffer/cleanrl_puffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from pokemonred_puffer.eval import make_pokemon_red_overlay
from pokemonred_puffer.global_map import GLOBAL_MAP_SHAPE
from pokemonred_puffer.profile import Profile, Utilization
from pokemonred_puffer.wrappers.sqlite import SqliteStateResetWrapper

pyximport.install(setup_args={"include_dirs": np.get_include()})
from pokemonred_puffer.c_gae import compute_gae # type: ignore # noqa: E402
Expand Down Expand Up @@ -379,37 +380,39 @@ def evaluate(self):
)
]
if self.sqlite_db:
with sqlite3.connect(self.sqlite_db) as conn:
cur = conn.cursor()
cur.executemany(
"""
UPDATE states
SET pyboy_state=:state,
reset=:reset
WHERE env_id=:env_id
""",
tuple(
[
{"state": state, "reset": 1, "env_id": env_id}
for state, env_id in zip(
new_states, self.event_tracker.keys()
)
]
),
)
with SqliteStateResetWrapper.DB_LOCK:
with sqlite3.connect(self.sqlite_db) as conn:
cur = conn.cursor()
cur.executemany(
"""
UPDATE states
SET pyboy_state=:state,
reset=:reset
WHERE env_id=:env_id
""",
tuple(
[
{"state": state, "reset": 1, "env_id": env_id}
for state, env_id in zip(
new_states, self.event_tracker.keys()
)
]
),
)
self.vecenv.async_reset()
# drain any INFO
key_set = self.event_tracker.keys()
while True:
# We connect each time just in case we block the wrappers
with sqlite3.connect(self.sqlite_db) as conn:
cur = conn.cursor()
resets = cur.execute(
"""
SELECT reset, env_id
FROM states
""",
).fetchall()
with SqliteStateResetWrapper.DB_LOCK:
with sqlite3.connect(self.sqlite_db) as conn:
cur = conn.cursor()
resets = cur.execute(
"""
SELECT reset, env_id
FROM states
""",
).fetchall()
if all(not reset for reset, env_id in resets if env_id in key_set):
break
time.sleep(0.5)
Expand Down
71 changes: 38 additions & 33 deletions pokemonred_puffer/wrappers/sqlite.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from os import PathLike
import multiprocessing
import os
import sqlite3
from typing import Any
Expand All @@ -9,47 +10,51 @@


class SqliteStateResetWrapper(gym.Wrapper):
DB_LOCK = multiprocessing.Lock()

def __init__(
self,
env: RedGymEnv,
database: str | bytes | PathLike[str] | PathLike[bytes],
):
super().__init__(env)
self.database = database
with sqlite3.connect(database) as conn:
cur = conn.cursor()
cur.execute(
"""
INSERT INTO states(env_id, pyboy_state, reset, pid)
VALUES(?, ?, ?, ?)
""",
(self.env.unwrapped.env_id, b"", 0, os.getpid()),
)
with SqliteStateResetWrapper.DB_LOCK:
with sqlite3.connect(database) as conn:
cur = conn.cursor()
cur.execute(
"""
INSERT INTO states(env_id, pyboy_state, reset, pid)
VALUES(?, ?, ?, ?)
""",
(self.env.unwrapped.env_id, b"", 0, os.getpid()),
)
print(f"Initialized sqlite row {self.env.unwrapped.env_id}")

def reset(self, seed: int | None = None, options: dict[str, Any] | None = None):
with sqlite3.connect(self.database) as conn:
cur = conn.cursor()
reset, pyboy_state = cur.execute(
"""
SELECT reset, pyboy_state
FROM states
WHERE env_id = ?
""",
(self.env.unwrapped.env_id,),
).fetchone()
if reset:
if options:
options["state"] = pyboy_state
else:
options = {"state": pyboy_state}
res = self.env.reset(seed=seed, options=options)
cur.execute(
"""
UPDATE states
SET reset = 0
WHERE env_id = ?
""",
(self.env.unwrapped.env_id,),
)
with SqliteStateResetWrapper.DB_LOCK:
with sqlite3.connect(self.database) as conn:
cur = conn.cursor()
reset, pyboy_state = cur.execute(
"""
SELECT reset, pyboy_state
FROM states
WHERE env_id = ?
""",
(self.env.unwrapped.env_id,),
).fetchone()
if reset:
if options:
options["state"] = pyboy_state
else:
options = {"state": pyboy_state}
res = self.env.reset(seed=seed, options=options)
cur.execute(
"""
UPDATE states
SET reset = 0
WHERE env_id = ?
""",
(self.env.unwrapped.env_id,),
)
return res

0 comments on commit 6c5143a

Please sign in to comment.