Skip to content

Commit

Permalink
Add sqlite wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Oct 18, 2024
1 parent d105d30 commit b6d6179
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 70 deletions.
12 changes: 7 additions & 5 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ wandb:

debug:
env:
headless: True
headless: False
stream_wrapper: False
init_state: "victory_road_5"
state_dir: pyboy_states
Expand All @@ -25,13 +25,13 @@ debug:
num_envs: 1
envs_per_worker: 1
num_workers: 1
env_batch_size: 4
env_batch_size: 128
zero_copy: False
batch_size: 4
minibatch_size: 4
batch_size: 1024
minibatch_size: 128
batch_rows: 4
bptt_horizon: 2
total_timesteps: 16
total_timesteps: 1_000_000
save_checkpoint: True
checkpoint_interval: 4
save_overlay: True
Expand All @@ -40,6 +40,7 @@ debug:
env_pool: False
load_optimizer_state: False
async_wrapper: False
sqlite_wrapper: True
archive_states: False

env:
Expand Down Expand Up @@ -130,6 +131,7 @@ train:
load_optimizer_state: False
use_rnn: True
async_wrapper: True
sqlite_wrapper: True
archive_states: True
swarm: True

Expand Down
41 changes: 29 additions & 12 deletions pokemonred_puffer/cleanrl_puffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ class CleanPuffeRL:
policy: nn.Module
env_send_queues: list[Queue]
env_recv_queues: list[Queue]
sqlite_db: str | None
wandb_client: wandb.wandb_sdk.wandb_run.Run | None = None
profile: Profile = field(default_factory=lambda: Profile())
losses: Losses = field(default_factory=lambda: Losses())
Expand Down Expand Up @@ -205,9 +206,9 @@ def __post_init__(self):
self.archive_path.mkdir(exist_ok=False)
print(f"Will archive states to {self.archive_path}")

self.conn = sqlite3.connect("states.db")
self.cur = self.conn.cursor()
self.cur.execute("CREATE TABLE states(env_id INT PRIMARY_KEY, state TEXT)")
if self.sqlite_db:
self.conn = sqlite3.connect(self.sqlite_db)
self.cur = self.conn.cursor()

@pufferlib.utils.profile
def evaluate(self):
Expand Down Expand Up @@ -293,7 +294,7 @@ def evaluate(self):
# progressing
# env id in async queues is the index within self.infos - self.config.num_envs + 1
if (
self.config.async_wrapper
(self.config.async_wrapper or self.config.sqlite_wrapper)
and hasattr(self.config, "swarm")
and self.config.swarm
and "required_count" in self.infos
Expand Down Expand Up @@ -344,18 +345,34 @@ def evaluate(self):
# Until then env ids are 1-indexed
print(f"\tNew events ({len(new_state_key)}): {new_state_key}")
new_states = [
"({state})"
state
for state in random.choices(
self.states[new_state_key], k=len(self.event_tracker.keys())
)
]
self.cur.execute(
"INSERT INTO states(state) VALUES "
f"{','.join(new_states)} "
"ON CONFLICT(env_id) "
"DO UPDATE SET state=EXCLUDED.state;"
)
self.vecenv.async_reset()
if self.sqlite_db:
self.cur.executemany(
"""
UPDATE states
SET state=?
SET reset=?
WHERE env_id=?
""",
tuple(
[
(state, True, env_id)
for state, env_id in zip(new_states, self.event_tracker.keys())
]
),
)
self.vecenv.async_reset()
if self.config.async_wrapper:
for key, state in zip(self.event_tracker.keys(), new_states):
self.env_recv_queues[key].put(state)
for key in self.event_tracker.keys():
# print(f"\tWaiting for message from env-id {key}")
self.env_send_queues[key].get()

print(
f"State migration to {self.archive_path}/{str(hash(new_state_key))} complete"
)
Expand Down
19 changes: 6 additions & 13 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import random
from collections import deque
from pathlib import Path
import sqlite3
from typing import Any, Iterable, Optional
import uuid

Expand Down Expand Up @@ -173,8 +172,6 @@ def __init__(self, env_config: DictConfig):
self.map_frame_writer = None
self.reset_count = 0
self.all_runs = []
self.conn = sqlite3.connect("states.db")
self.cur = self.conn.cursor()

# Set this in SOME subclasses
self.metadata = {"render.modes": []}
Expand Down Expand Up @@ -307,7 +304,6 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] =

infos = {}
self.explore_map_dim = 384
# res = self.cur.execute(f"SELECT state FROM states WHERE env_id={self.env_id}")
if self.first or options.get("state", None) is not None:
# We only init seen hidden objs once cause they can only be found once!
if options.get("state", None) is not None:
Expand Down Expand Up @@ -733,6 +729,7 @@ def step(self, action):
self.step_count += 1

# cut mon check
reset = False
if not self.party_has_cut_capable_mon():
reset = True
self.first = True
Expand Down Expand Up @@ -1590,16 +1587,12 @@ def get_game_state_reward(self):

def update_max_op_level(self):
# opp_base_level = 5
opponent_level = (
max(
[
self.read_m(f"wEnemyMon{i+1}Level")
for i in range(self.read_m("wEnemyPartyCount"))
]
+ [0]
)
# - opp_base_level
opponent_level = max(
[0]
+ [self.read_m(f"wEnemyMon{i+1}Level") for i in range(self.read_m("wEnemyPartyCount"))]
)
# - opp_base_level

self.max_opponent_level = max(0, self.max_opponent_level, opponent_level)
return self.max_opponent_level

Expand Down
113 changes: 73 additions & 40 deletions pokemonred_puffer/train.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import functools
import importlib
import os
import sqlite3
from tempfile import NamedTemporaryFile
import uuid
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from enum import Enum
from multiprocessing import Queue
from pathlib import Path
Expand All @@ -21,6 +23,7 @@
from pokemonred_puffer.cleanrl_puffer import CleanPuffeRL
from pokemonred_puffer.environment import RedGymEnv
from pokemonred_puffer.wrappers.async_io import AsyncWrapper
from pokemonred_puffer.wrappers.sqlite import SqliteStateResetWrapper

app = typer.Typer(pretty_exceptions_enable=False)

Expand Down Expand Up @@ -62,26 +65,30 @@ def load_from_config(config: DictConfig, debug: bool) -> DictConfig:
def make_env_creator(
wrapper_classes: list[tuple[str, ModuleType]],
reward_class: RedGymEnv,
async_wrapper: bool = True,
async_wrapper: bool = False,
sqlite_wrapper: bool = False,
) -> Callable[[DictConfig, DictConfig], pufferlib.emulation.GymnasiumPufferEnv]:
def env_creator(
env_config: DictConfig,
wrappers_config: list[dict[str, Any]],
reward_config: DictConfig,
async_config: dict[str, Queue] | None = None,
sqlite_config: dict[str, str] | None = None,
) -> pufferlib.emulation.GymnasiumPufferEnv:
env = reward_class(env_config, reward_config)
for cfg, (_, wrapper_class) in zip(wrappers_config, wrapper_classes):
env = wrapper_class(env, OmegaConf.create([x for x in cfg.values()][0]))
if async_wrapper and async_config:
env = AsyncWrapper(env, async_config["send_queues"], async_config["recv_queues"])
if sqlite_wrapper and sqlite_config:
env = SqliteStateResetWrapper(env, sqlite_config["database"])
return pufferlib.emulation.GymnasiumPufferEnv(env=env)

return env_creator


def setup_agent(
wrappers: list[str], reward_name: str, async_wrapper: bool = True
wrappers: list[str], reward_name: str, async_wrapper: bool = False, sqlite_wrapper: bool = False
) -> Callable[[DictConfig, DictConfig], pufferlib.emulation.GymnasiumPufferEnv]:
# TODO: Make this less dependent on the name of this repo and its file structure
wrapper_classes = [
Expand All @@ -100,7 +107,7 @@ def setup_agent(
importlib.import_module(f"pokemonred_puffer.rewards.{reward_module}"), reward_class_name
)
# NOTE: This assumes reward_module has RewardWrapper(RedGymEnv) class
env_creator = make_env_creator(wrapper_classes, reward_class, async_wrapper)
env_creator = make_env_creator(wrapper_classes, reward_class, async_wrapper, sqlite_wrapper)

return env_creator

Expand Down Expand Up @@ -159,7 +166,10 @@ def setup(
config.vectorization = Vectorization.serial

async_wrapper = config.train.get("async_wrapper", False)
env_creator = setup_agent(config.wrappers[wrappers_name], reward_name, async_wrapper)
sqlite_wrapper = config.train.get("sqlite_wrapper", False)
env_creator = setup_agent(
config.wrappers[wrappers_name], reward_name, async_wrapper, sqlite_wrapper
)
return config, env_creator


Expand Down Expand Up @@ -335,41 +345,64 @@ def train(
vec = pufferlib.vector.Multiprocessing

# TODO: Remove the +1 once the driver env doesn't permanently increase the env id
env_send_queues = [Queue() for _ in range(2 * config.train.num_envs + 1)]
env_recv_queues = [Queue() for _ in range(2 * config.train.num_envs + 1)]

vecenv = pufferlib.vector.make(
env_creator,
env_kwargs={
"env_config": config.env,
"wrappers_config": config.wrappers[wrappers_name],
"reward_config": config.rewards[reward_name]["reward"],
"async_config": {"send_queues": env_send_queues, "recv_queues": env_recv_queues},
},
num_envs=config.train.num_envs,
num_workers=config.train.num_workers,
batch_size=config.train.env_batch_size,
zero_copy=config.train.zero_copy,
backend=vec,
)
policy = make_policy(vecenv.driver_env, policy_name, config)

config.train.env = "Pokemon Red"
trainer = CleanPuffeRL(
exp_name=exp_name,
config=config.train,
vecenv=vecenv,
policy=policy,
env_recv_queues=env_recv_queues,
env_send_queues=env_send_queues,
wandb_client=wandb_client,
)
while not trainer.done_training():
trainer.evaluate()
trainer.train()

trainer.close()
print("Done training")
env_send_queues = []
env_recv_queues = []
if config.train.get("async_wrapper", False):
env_send_queues = [Queue() for _ in range(2 * config.train.num_envs + 1)]
env_recv_queues = [Queue() for _ in range(2 * config.train.num_envs + 1)]

sqlite_context = nullcontext
if config.train.get("sqlite_wrapper", False):
sqlite_context = NamedTemporaryFile

with sqlite_context() as sqlite_db:
db_filename = None
if config.train.get("sqlite_wrapper", False):
db_filename = sqlite_db.name
conn = sqlite3.connect(db_filename)
cur = conn.cursor()
cur.execute(
"CREATE TABLE states(env_id INT PRIMARY_KEY, pyboy_state BLOB, reset BOOLEAN);"
)
cur.close()

vecenv = pufferlib.vector.make(
env_creator,
env_kwargs={
"env_config": config.env,
"wrappers_config": config.wrappers[wrappers_name],
"reward_config": config.rewards[reward_name]["reward"],
"async_config": {
"send_queues": env_send_queues,
"recv_queues": env_recv_queues,
},
"sqlite_config": {"database": db_filename},
},
num_envs=config.train.num_envs,
num_workers=config.train.num_workers,
batch_size=config.train.env_batch_size,
zero_copy=config.train.zero_copy,
backend=vec,
)
policy = make_policy(vecenv.driver_env, policy_name, config)

config.train.env = "Pokemon Red"
trainer = CleanPuffeRL(
exp_name=exp_name,
config=config.train,
vecenv=vecenv,
policy=policy,
env_recv_queues=env_recv_queues,
env_send_queues=env_send_queues,
sqlite_db=db_filename,
wandb_client=wandb_client,
)
while not trainer.done_training():
trainer.evaluate()
trainer.train()

trainer.close()
print("Done training")


if __name__ == "__main__":
Expand Down
50 changes: 50 additions & 0 deletions pokemonred_puffer/wrappers/sqlite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from os import PathLike
import sqlite3
from typing import Any

import gymnasium as gym

from pokemonred_puffer.environment import RedGymEnv


class SqliteStateResetWrapper(gym.Wrapper):
    """Gym wrapper that applies pending state resets stored in a SQLite table.

    The wrapper talks to a ``states`` table with the columns ``env_id``,
    ``pyboy_state`` and ``reset``. On construction it registers a row for the
    wrapped environment. On every ``reset`` it checks whether the ``reset``
    flag of that row is set and, if so, forwards the stored ``pyboy_state`` to
    the wrapped environment through ``options["state"]``.

    NOTE(review): the row is presumably updated by a separate process (the
    trainer) that wants to migrate this environment to a new state -- confirm
    against the caller that writes to ``states``.
    """

    def __init__(
        self,
        env: RedGymEnv,
        database: str | bytes | PathLike[str] | PathLike[bytes],
    ):
        """Open the database and register this environment in ``states``.

        Args:
            env: Environment to wrap; ``env.unwrapped.env_id`` is used as the
                row key.
            database: Path of the SQLite database file. The ``states`` table
                must already exist.
        """
        super().__init__(env)
        self.conn = sqlite3.connect(database)
        self.cur = self.conn.cursor()
        self.cur.execute(
            """
            INSERT INTO states(env_id, pyboy_state, reset)
            VALUES(?, ?, ?)
            """,
            (self.env.unwrapped.env_id, b"", False),
        )
        # sqlite3 opens an implicit transaction for INSERT/UPDATE. Without a
        # commit the row is invisible to other connections and the write lock
        # stays held, which blocks every other writer on the same file.
        self.conn.commit()

    def reset(self, seed: int | None = None, options: dict[str, Any] | None = None):
        """Reset the wrapped env, injecting a stored state if one is pending.

        Args:
            seed: Forwarded unchanged to the wrapped environment.
            options: Forwarded to the wrapped environment. When a reset is
                pending, a copy with ``"state"`` set to the stored
                ``pyboy_state`` is forwarded instead; the caller's dict is
                left untouched.

        Returns:
            Whatever the wrapped environment's ``reset`` returns.
        """
        env_id = self.env.unwrapped.env_id
        # fetchall() drains the SELECT so the statement does not stay active
        # (and potentially holding a read lock) until the next reset.
        rows = self.cur.execute(
            """
            SELECT reset, pyboy_state
            FROM states
            WHERE env_id = ?
            """,
            (env_id,),
        ).fetchall()
        # A missing row simply means there is nothing pending.
        pending = bool(rows) and bool(rows[0][0])
        if pending:
            options = {**(options or {}), "state": rows[0][1]}

        res = self.env.reset(seed=seed, options=options)

        if pending:
            # Only clear the flag we actually consumed, so a request that was
            # not seen by the SELECT above is not wiped out. The flag is bound
            # as a parameter rather than the FALSE keyword, which older SQLite
            # versions do not understand.
            self.cur.execute(
                """
                UPDATE states
                SET reset = ?
                WHERE env_id = ?
                """,
                (False, env_id),
            )
            self.conn.commit()
        return res

    def close(self):
        """Close the SQLite connection, then close the wrapped environment."""
        self.conn.close()
        return super().close()

0 comments on commit b6d6179

Please sign in to comment.