From b6d6179027ee1361bc83f8ffc5c1c21d9e9009f2 Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Thu, 17 Oct 2024 21:55:31 -0400 Subject: [PATCH] Add sqlite wrapper --- config.yaml | 12 +-- pokemonred_puffer/cleanrl_puffer.py | 41 +++++++--- pokemonred_puffer/environment.py | 19 ++--- pokemonred_puffer/train.py | 113 +++++++++++++++++---------- pokemonred_puffer/wrappers/sqlite.py | 50 ++++++++++++ 5 files changed, 165 insertions(+), 70 deletions(-) create mode 100644 pokemonred_puffer/wrappers/sqlite.py diff --git a/config.yaml b/config.yaml index cb01e79..15a67fe 100644 --- a/config.yaml +++ b/config.yaml @@ -5,7 +5,7 @@ wandb: debug: env: - headless: True + headless: False stream_wrapper: False init_state: "victory_road_5" state_dir: pyboy_states @@ -25,13 +25,13 @@ debug: num_envs: 1 envs_per_worker: 1 num_workers: 1 - env_batch_size: 4 + env_batch_size: 128 zero_copy: False - batch_size: 4 - minibatch_size: 4 + batch_size: 1024 + minibatch_size: 128 batch_rows: 4 bptt_horizon: 2 - total_timesteps: 16 + total_timesteps: 1_000_000 save_checkpoint: True checkpoint_interval: 4 save_overlay: True @@ -40,6 +40,7 @@ debug: env_pool: False load_optimizer_state: False async_wrapper: False + sqlite_wrapper: True archive_states: False env: @@ -130,6 +131,7 @@ train: load_optimizer_state: False use_rnn: True async_wrapper: True + sqlite_wrapper: True archive_states: True swarm: True diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py index 2157a39..e1eca8f 100644 --- a/pokemonred_puffer/cleanrl_puffer.py +++ b/pokemonred_puffer/cleanrl_puffer.py @@ -130,6 +130,7 @@ class CleanPuffeRL: policy: nn.Module env_send_queues: list[Queue] env_recv_queues: list[Queue] + sqlite_db: str | None wandb_client: wandb.wandb_sdk.wandb_run.Run | None = None profile: Profile = field(default_factory=lambda: Profile()) losses: Losses = field(default_factory=lambda: Losses()) @@ -205,9 +206,9 @@ def __post_init__(self): self.archive_path.mkdir(exist_ok=False) print(f"Will archive states to {self.archive_path}") - self.conn = sqlite3.connect("states.db") - self.cur = self.conn.cursor() - self.cur.execute("CREATE TABLE states(env_id INT PRIMARY_KEY, state TEXT)") + if self.sqlite_db: + self.conn = sqlite3.connect(self.sqlite_db) + self.cur = self.conn.cursor() @pufferlib.utils.profile def evaluate(self): @@ -293,7 +294,7 @@ def evaluate(self): # progressing # env id in async queues is the index within self.infos - self.config.num_envs + 1 if ( - self.config.async_wrapper + (self.config.async_wrapper or self.config.sqlite_wrapper) and hasattr(self.config, "swarm") and self.config.swarm and "required_count" in self.infos @@ -344,18 +345,34 @@ def evaluate(self): # Until then env ids are 1-indexed print(f"\tNew events ({len(new_state_key)}): {new_state_key}") new_states = [ - "({state})" + state for state in random.choices( self.states[new_state_key], k=len(self.event_tracker.keys()) ) ] - self.cur.execute( - "INSERT INTO states(state) VALUES " - f"{','.join(new_states)} " - "ON CONFLICT(env_id) " - "DO UPDATE SET state=EXCLUDED.state;" - ) - self.vecenv.async_reset() + if self.sqlite_db: + self.cur.executemany( + """ + UPDATE states + SET state=? + SET reset=? + WHERE env_id=? + """, + tuple( + [ + (state, True, env_id) + for state, env_id in zip(new_states, self.event_tracker.keys()) + ] + ), + ) + self.vecenv.async_reset() + if self.config.async_wrapper: + for key, state in zip(self.event_tracker.keys(), new_states): + self.env_recv_queues[key].put(state) + for key in self.event_tracker.keys(): + # print(f"\tWaiting for message from env-id {key}") + self.env_send_queues[key].get() + print( f"State migration to {self.archive_path}/{str(hash(new_state_key))} complete" ) diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 1ff8cc4..f8d3bdf 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -5,7 +5,6 @@ import random from collections import deque from pathlib import Path -import sqlite3 from typing import Any, Iterable, Optional import uuid @@ -173,8 +172,6 @@ def __init__(self, env_config: DictConfig): self.map_frame_writer = None self.reset_count = 0 self.all_runs = [] - self.conn = sqlite3.connect("states.db") - self.cur = self.conn.cursor() # Set this in SOME subclasses self.metadata = {"render.modes": []} @@ -307,7 +304,6 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = infos = {} self.explore_map_dim = 384 - # res = self.cur.execute(f"SELECT state FROM states WHERE env_id={self.env_id}") if self.first or options.get("state", None) is not None: # We only init seen hidden objs once cause they can only be found once! if options.get("state", None) is not None: @@ -733,6 +729,7 @@ def step(self, action): self.step_count += 1 # cut mon check + reset = False if not self.party_has_cut_capable_mon(): reset = True self.first = True @@ -1590,16 +1587,12 @@ def get_game_state_reward(self): def update_max_op_level(self): # opp_base_level = 5 - opponent_level = ( - max( - [ - self.read_m(f"wEnemyMon{i+1}Level") - for i in range(self.read_m("wEnemyPartyCount")) - ] - + [0] - ) - # - opp_base_level + opponent_level = max( + [0] + + [self.read_m(f"wEnemyMon{i+1}Level") for i in range(self.read_m("wEnemyPartyCount"))] ) + # - opp_base_level + self.max_opponent_level = max(0, self.max_opponent_level, opponent_level) return self.max_opponent_level diff --git a/pokemonred_puffer/train.py b/pokemonred_puffer/train.py index 4b8741b..215bff3 100644 --- a/pokemonred_puffer/train.py +++ b/pokemonred_puffer/train.py @@ -1,8 +1,10 @@ import functools import importlib import os +import sqlite3 +from tempfile import NamedTemporaryFile import uuid -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from enum import Enum from multiprocessing import Queue from pathlib import Path @@ -21,6 +23,7 @@ from pokemonred_puffer.cleanrl_puffer import CleanPuffeRL from pokemonred_puffer.environment import RedGymEnv from pokemonred_puffer.wrappers.async_io import AsyncWrapper +from pokemonred_puffer.wrappers.sqlite import SqliteStateResetWrapper app = typer.Typer(pretty_exceptions_enable=False) @@ -62,26 +65,30 @@ def load_from_config(config: DictConfig, debug: bool) -> DictConfig: def make_env_creator( wrapper_classes: list[tuple[str, ModuleType]], reward_class: RedGymEnv, - async_wrapper: bool = True, + async_wrapper: bool = False, + sqlite_wrapper: bool = False, ) -> Callable[[DictConfig, DictConfig], pufferlib.emulation.GymnasiumPufferEnv]: def env_creator( env_config: DictConfig, wrappers_config: list[dict[str, Any]], reward_config: DictConfig, async_config: dict[str, Queue] | None = None, + sqlite_config: dict[str, str] | None = None, ) -> pufferlib.emulation.GymnasiumPufferEnv: env = reward_class(env_config, reward_config) for cfg, (_, wrapper_class) in zip(wrappers_config, wrapper_classes): env = wrapper_class(env, OmegaConf.create([x for x in cfg.values()][0])) if async_wrapper and async_config: env = AsyncWrapper(env, async_config["send_queues"], async_config["recv_queues"]) + if sqlite_wrapper and sqlite_config: + env = SqliteStateResetWrapper(env, sqlite_config["database"]) return pufferlib.emulation.GymnasiumPufferEnv(env=env) return env_creator def setup_agent( - wrappers: list[str], reward_name: str, async_wrapper: bool = True + wrappers: list[str], reward_name: str, async_wrapper: bool = False, sqlite_wrapper: bool = False ) -> Callable[[DictConfig, DictConfig], pufferlib.emulation.GymnasiumPufferEnv]: # TODO: Make this less dependent on the name of this repo and its file structure wrapper_classes = [ @@ -100,7 +107,7 @@ def setup_agent( importlib.import_module(f"pokemonred_puffer.rewards.{reward_module}"), reward_class_name ) # NOTE: This assumes reward_module has RewardWrapper(RedGymEnv) class - env_creator = make_env_creator(wrapper_classes, reward_class, async_wrapper) + env_creator = make_env_creator(wrapper_classes, reward_class, async_wrapper, sqlite_wrapper) return env_creator @@ -159,7 +166,10 @@ def setup( config.vectorization = Vectorization.serial async_wrapper = config.train.get("async_wrapper", False) - env_creator = setup_agent(config.wrappers[wrappers_name], reward_name, async_wrapper) + sqlite_wrapper = config.train.get("sqlite_wrapper", False) + env_creator = setup_agent( + config.wrappers[wrappers_name], reward_name, async_wrapper, sqlite_wrapper + ) return config, env_creator @@ -335,41 +345,64 @@ def train( vec = pufferlib.vector.Multiprocessing # TODO: Remove the +1 once the driver env doesn't permanently increase the env id - env_send_queues = [Queue() for _ in range(2 * config.train.num_envs + 1)] - env_recv_queues = [Queue() for _ in range(2 * config.train.num_envs + 1)] - - vecenv = pufferlib.vector.make( - env_creator, - env_kwargs={ - "env_config": config.env, - "wrappers_config": config.wrappers[wrappers_name], - "reward_config": config.rewards[reward_name]["reward"], - "async_config": {"send_queues": env_send_queues, "recv_queues": env_recv_queues}, - }, - num_envs=config.train.num_envs, - num_workers=config.train.num_workers, - batch_size=config.train.env_batch_size, - zero_copy=config.train.zero_copy, - backend=vec, - ) - policy = make_policy(vecenv.driver_env, policy_name, config) - - config.train.env = "Pokemon Red" - trainer = CleanPuffeRL( - exp_name=exp_name, - config=config.train, - vecenv=vecenv, - policy=policy, - env_recv_queues=env_recv_queues, - env_send_queues=env_send_queues, - wandb_client=wandb_client, - ) - while not trainer.done_training(): - trainer.evaluate() - trainer.train() - - trainer.close() - print("Done training") + env_send_queues = [] + env_recv_queues = [] + if config.train.get("async_wrapper", False): + env_send_queues = [Queue() for _ in range(2 * config.train.num_envs + 1)] + env_recv_queues = [Queue() for _ in range(2 * config.train.num_envs + 1)] + + sqlite_context = nullcontext + if config.train.get("sqlite_wrapper", False): + sqlite_context = NamedTemporaryFile + + with sqlite_context() as sqlite_db: + db_filename = None + if config.train.get("sqlite_wrapper", False): + db_filename = sqlite_db.name + conn = sqlite3.connect(db_filename) + cur = conn.cursor() + cur.execute( + "CREATE TABLE states(env_id INT PRIMARY_KEY, pyboy_state BLOB, reset BOOLEAN);" + ) + cur.close() + + vecenv = pufferlib.vector.make( + env_creator, + env_kwargs={ + "env_config": config.env, + "wrappers_config": config.wrappers[wrappers_name], + "reward_config": config.rewards[reward_name]["reward"], + "async_config": { + "send_queues": env_send_queues, + "recv_queues": env_recv_queues, + }, + "sqlite_config": {"database": db_filename}, + }, + num_envs=config.train.num_envs, + num_workers=config.train.num_workers, + batch_size=config.train.env_batch_size, + zero_copy=config.train.zero_copy, + backend=vec, + ) + policy = make_policy(vecenv.driver_env, policy_name, config) + + config.train.env = "Pokemon Red" + trainer = CleanPuffeRL( + exp_name=exp_name, + config=config.train, + vecenv=vecenv, + policy=policy, + env_recv_queues=env_recv_queues, + env_send_queues=env_send_queues, + sqlite_db=db_filename, + wandb_client=wandb_client, + ) + while not trainer.done_training(): + trainer.evaluate() + trainer.train() + + trainer.close() + print("Done training") if __name__ == "__main__": diff --git a/pokemonred_puffer/wrappers/sqlite.py b/pokemonred_puffer/wrappers/sqlite.py new file mode 100644 index 0000000..a6ff7e8 --- /dev/null +++ b/pokemonred_puffer/wrappers/sqlite.py @@ -0,0 +1,50 @@ +from os import PathLike +import sqlite3 +from typing import Any + +import gymnasium as gym + +from pokemonred_puffer.environment import RedGymEnv + + +class SqliteStateResetWrapper(gym.Wrapper): + def __init__( + self, + env: RedGymEnv, + database: str | bytes | PathLike[str] | PathLike[bytes], + ): + super().__init__(env) + self.conn = sqlite3.connect(database) + self.cur = self.conn.cursor() + self.cur.execute( + """ + INSERT INTO states(env_id, pyboy_state, reset) + VALUES(?, ?, ?) + """, + (self.env.unwrapped.env_id, b"", False), + ) + + def reset(self, seed: int | None = None, options: dict[str, Any] | None = None): + reset, pyboy_state = self.cur.execute( + """ + SELECT reset, pyboy_state + FROM states + WHERE env_id = ? + """, + (self.env.unwrapped.env_id,), + ).fetchone() + if reset: + if options: + options["state"] = pyboy_state + else: + options = {"state": pyboy_state} + res = self.env.reset(seed=seed, options=options) + self.cur.execute( + """ + UPDATE states + SET reset = False + WHERE env_id = ? + """, + (self.env.unwrapped.env_id,), + ) + return res