diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py index 373f5bc..2e4f5d5 100644 --- a/pokemonred_puffer/cleanrl_puffer.py +++ b/pokemonred_puffer/cleanrl_puffer.py @@ -33,6 +33,7 @@ from pokemonred_puffer.eval import make_pokemon_red_overlay from pokemonred_puffer.global_map import GLOBAL_MAP_SHAPE from pokemonred_puffer.profile import Profile, Utilization +from pokemonred_puffer.wrappers.sqlite import SqliteStateResetWrapper pyximport.install(setup_args={"include_dirs": np.get_include()}) from pokemonred_puffer.c_gae import compute_gae # type: ignore # noqa: E402 @@ -379,37 +380,39 @@ def evaluate(self): ) ] if self.sqlite_db: - with sqlite3.connect(self.sqlite_db) as conn: - cur = conn.cursor() - cur.executemany( - """ - UPDATE states - SET pyboy_state=:state, - reset=:reset - WHERE env_id=:env_id - """, - tuple( - [ - {"state": state, "reset": 1, "env_id": env_id} - for state, env_id in zip( - new_states, self.event_tracker.keys() - ) - ] - ), - ) + with SqliteStateResetWrapper.DB_LOCK: + with sqlite3.connect(self.sqlite_db) as conn: + cur = conn.cursor() + cur.executemany( + """ + UPDATE states + SET pyboy_state=:state, + reset=:reset + WHERE env_id=:env_id + """, + tuple( + [ + {"state": state, "reset": 1, "env_id": env_id} + for state, env_id in zip( + new_states, self.event_tracker.keys() + ) + ] + ), + ) self.vecenv.async_reset() # drain any INFO key_set = self.event_tracker.keys() while True: # We connect each time just in case we block the wrappers - with sqlite3.connect(self.sqlite_db) as conn: - cur = conn.cursor() - resets = cur.execute( - """ - SELECT reset, env_id - FROM states - """, - ).fetchall() + with SqliteStateResetWrapper.DB_LOCK: + with sqlite3.connect(self.sqlite_db) as conn: + cur = conn.cursor() + resets = cur.execute( + """ + SELECT reset, env_id + FROM states + """, + ).fetchall() if all(not reset for reset, env_id in resets if env_id in key_set): break time.sleep(0.5) diff --git a/pokemonred_puffer/wrappers/sqlite.py b/pokemonred_puffer/wrappers/sqlite.py index 3becf5e..ae188bc 100644 --- a/pokemonred_puffer/wrappers/sqlite.py +++ b/pokemonred_puffer/wrappers/sqlite.py @@ -1,4 +1,5 @@ from os import PathLike +import multiprocessing import os import sqlite3 from typing import Any @@ -9,6 +10,8 @@ class SqliteStateResetWrapper(gym.Wrapper): + DB_LOCK = multiprocessing.Lock() + def __init__( self, env: RedGymEnv, @@ -16,40 +19,42 @@ def __init__( ): super().__init__(env) self.database = database - with sqlite3.connect(database) as conn: - cur = conn.cursor() - cur.execute( - """ - INSERT INTO states(env_id, pyboy_state, reset, pid) - VALUES(?, ?, ?, ?) - """, - (self.env.unwrapped.env_id, b"", 0, os.getpid()), - ) + with SqliteStateResetWrapper.DB_LOCK: + with sqlite3.connect(database) as conn: + cur = conn.cursor() + cur.execute( + """ + INSERT INTO states(env_id, pyboy_state, reset, pid) + VALUES(?, ?, ?, ?) + """, + (self.env.unwrapped.env_id, b"", 0, os.getpid()), + ) print(f"Initialized sqlite row {self.env.unwrapped.env_id}") def reset(self, seed: int | None = None, options: dict[str, Any] | None = None): - with sqlite3.connect(self.database) as conn: - cur = conn.cursor() - reset, pyboy_state = cur.execute( - """ - SELECT reset, pyboy_state - FROM states - WHERE env_id = ? - """, - (self.env.unwrapped.env_id,), - ).fetchone() - if reset: - if options: - options["state"] = pyboy_state - else: - options = {"state": pyboy_state} - res = self.env.reset(seed=seed, options=options) - cur.execute( - """ - UPDATE states - SET reset = 0 - WHERE env_id = ? - """, - (self.env.unwrapped.env_id,), - ) + with SqliteStateResetWrapper.DB_LOCK: + with sqlite3.connect(self.database) as conn: + cur = conn.cursor() + reset, pyboy_state = cur.execute( + """ + SELECT reset, pyboy_state + FROM states + WHERE env_id = ? + """, + (self.env.unwrapped.env_id,), + ).fetchone() + if reset: + if options: + options["state"] = pyboy_state + else: + options = {"state": pyboy_state} + res = self.env.reset(seed=seed, options=options) + cur.execute( + """ + UPDATE states + SET reset = 0 + WHERE env_id = ? + """, + (self.env.unwrapped.env_id,), + ) return res