diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py index e1eca8f..b5065f7 100644 --- a/pokemonred_puffer/cleanrl_puffer.py +++ b/pokemonred_puffer/cleanrl_puffer.py @@ -351,20 +351,24 @@ def evaluate(self): ) ] if self.sqlite_db: - self.cur.executemany( - """ - UPDATE states - SET state=? - SET reset=? - WHERE env_id=? - """, - tuple( - [ - (state, True, env_id) - for state, env_id in zip(new_states, self.event_tracker.keys()) - ] - ), - ) + with sqlite3.connect(self.sqlite_db) as conn: + cur = conn.cursor() + cur.executemany( + """ + UPDATE states + SET state=? + SET reset=? + WHERE env_id=? + """, + tuple( + [ + (state, True, env_id) + for state, env_id in zip( + new_states, self.event_tracker.keys() + ) + ] + ), + ) self.vecenv.async_reset() if self.config.async_wrapper: for key, state in zip(self.event_tracker.keys(), new_states): diff --git a/pokemonred_puffer/train.py b/pokemonred_puffer/train.py index 215bff3..234fc06 100644 --- a/pokemonred_puffer/train.py +++ b/pokemonred_puffer/train.py @@ -359,12 +359,11 @@ def train( db_filename = None if config.train.get("sqlite_wrapper", False): db_filename = sqlite_db.name - conn = sqlite3.connect(db_filename) - cur = conn.cursor() - cur.execute( - "CREATE TABLE states(env_id INT PRIMARY_KEY, pyboy_state BLOB, reset BOOLEAN);" - ) - cur.close() + with sqlite3.connect(db_filename) as conn: + cur = conn.cursor() + cur.execute( + "CREATE TABLE states(env_id INT PRIMARY_KEY, pyboy_state BLOB, reset BOOLEAN);" + ) vecenv = pufferlib.vector.make( env_creator, diff --git a/pokemonred_puffer/wrappers/sqlite.py b/pokemonred_puffer/wrappers/sqlite.py index a6ff7e8..f110376 100644 --- a/pokemonred_puffer/wrappers/sqlite.py +++ b/pokemonred_puffer/wrappers/sqlite.py @@ -14,37 +14,40 @@ def __init__( database: str | bytes | PathLike[str] | PathLike[bytes], ): super().__init__(env) - self.conn = sqlite3.connect(database) - self.cur = self.conn.cursor() - self.cur.execute( - """ - INSERT INTO states(env_id, pyboy_state, reset) - VALUES(?, ?, ?) - """, - (self.env.unwrapped.env_id, b"", False), - ) + self.database = database + with sqlite3.connect(database) as conn: + cur = conn.cursor() + cur.execute( + """ + INSERT INTO states(env_id, pyboy_state, reset) + VALUES(?, ?, ?) + """, + (self.env.unwrapped.env_id, b"", False), + ) def reset(self, seed: int | None = None, options: dict[str, Any] | None = None): - reset, pyboy_state = self.cur.execute( - """ - SELECT reset, pyboy_state - FROM states - WHERE env_id = ? - """, - (self.env.unwrapped.env_id,), - ).fetchone() - if reset: - if options: - options["state"] = pyboy_state - else: - options = {"state": pyboy_state} - res = self.env.reset(seed=seed, options=options) - self.cur.execute( - """ - UPDATE states - SET reset = False - WHERE env_id = ? - """, - (self.env.unwrapped.env_id,), - ) + with sqlite3.connect(self.database) as conn: + cur = conn.cursor() + reset, pyboy_state = cur.execute( + """ + SELECT reset, pyboy_state + FROM states + WHERE env_id = ? + """, + (self.env.unwrapped.env_id,), + ).fetchone() + if reset: + if options: + options["state"] = pyboy_state + else: + options = {"state": pyboy_state} + res = self.env.reset(seed=seed, options=options) + cur.execute( + """ + UPDATE states + SET reset = False + WHERE env_id = ? + """, + (self.env.unwrapped.env_id,), + ) return res