Skip to content

Commit

Permalink
commits
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Oct 18, 2024
1 parent b6d6179 commit c6c67ab
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 51 deletions.
32 changes: 18 additions & 14 deletions pokemonred_puffer/cleanrl_puffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,20 +351,24 @@ def evaluate(self):
)
]
if self.sqlite_db:
self.cur.executemany(
"""
UPDATE states
SET state=?
SET reset=?
WHERE env_id=?
""",
tuple(
[
(state, True, env_id)
for state, env_id in zip(new_states, self.event_tracker.keys())
]
),
)
with sqlite3.connect(self.sqlite_db) as conn:
cur = conn.cursor()
cur.executemany(
"""
UPDATE states
SET state=?
SET reset=?
WHERE env_id=?
""",
tuple(
[
(state, True, env_id)
for state, env_id in zip(
new_states, self.event_tracker.keys()
)
]
),
)
self.vecenv.async_reset()
if self.config.async_wrapper:
for key, state in zip(self.event_tracker.keys(), new_states):
Expand Down
11 changes: 5 additions & 6 deletions pokemonred_puffer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,12 +359,11 @@ def train(
db_filename = None
if config.train.get("sqlite_wrapper", False):
db_filename = sqlite_db.name
conn = sqlite3.connect(db_filename)
cur = conn.cursor()
cur.execute(
"CREATE TABLE states(env_id INT PRIMARY_KEY, pyboy_state BLOB, reset BOOLEAN);"
)
cur.close()
with sqlite3.connect(db_filename) as conn:
cur = conn.cursor()
cur.execute(
"CREATE TABLE states(env_id INT PRIMARY_KEY, pyboy_state BLOB, reset BOOLEAN);"
)

vecenv = pufferlib.vector.make(
env_creator,
Expand Down
65 changes: 34 additions & 31 deletions pokemonred_puffer/wrappers/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,37 +14,40 @@ def __init__(
database: str | bytes | PathLike[str] | PathLike[bytes],
):
super().__init__(env)
self.conn = sqlite3.connect(database)
self.cur = self.conn.cursor()
self.cur.execute(
"""
INSERT INTO states(env_id, pyboy_state, reset)
VALUES(?, ?, ?)
""",
(self.env.unwrapped.env_id, b"", False),
)
self.database = database
with sqlite3.connect(database) as conn:
cur = conn.cursor()
cur.execute(
"""
INSERT INTO states(env_id, pyboy_state, reset)
VALUES(?, ?, ?)
""",
(self.env.unwrapped.env_id, b"", False),
)

def reset(self, seed: int | None = None, options: dict[str, Any] | None = None):
reset, pyboy_state = self.cur.execute(
"""
SELECT reset, pyboy_state
FROM states
WHERE env_id = ?
""",
(self.env.unwrapped.env_id,),
).fetchone()
if reset:
if options:
options["state"] = pyboy_state
else:
options = {"state": pyboy_state}
res = self.env.reset(seed=seed, options=options)
self.cur.execute(
"""
UPDATE states
SET reset = False
WHERE env_id = ?
""",
(self.env.unwrapped.env_id,),
)
with sqlite3.connect(self.database) as conn:
cur = conn.cursor()
reset, pyboy_state = cur.execute(
"""
SELECT reset, pyboy_state
FROM states
WHERE env_id = ?
""",
(self.env.unwrapped.env_id,),
).fetchone()
if reset:
if options:
options["state"] = pyboy_state
else:
options = {"state": pyboy_state}
res = self.env.reset(seed=seed, options=options)
cur.execute(
"""
UPDATE states
SET reset = False
WHERE env_id = ?
""",
(self.env.unwrapped.env_id,),
)
return res

0 comments on commit c6c67ab

Please sign in to comment.