Patch summary (free text ahead of the first "diff --git" header; patch tools skip it).

Replaces the per-env queue hand-off in evaluate() with rows in a shared SQLite
file ("states.db"), makes environment.py's reset() rebuild its per-episode
tracking state on every call, and moves the exploration wrappers' reset-time
logic into step(), triggered when step_count >= get_max_steps().

Corrections folded into the added lines (hunk line counts are unchanged):
- CREATE TABLE used "PRIMARY_KEY", which SQLite reads as part of the column
  type, so env_id had no key and ON CONFLICT(env_id) could not match. It also
  failed whenever states.db already contained the table. A states.db created
  with the earlier schema must be deleted once.
- The INSERT spliced the literal text "({state})" (no f-prefix) into the SQL
  and never supplied env_id. It now binds (env_id, state) pairs and commits.
- OnResetLowerToFixedValueWrapper.step() returned self.env.unwrapped.step(),
  skipping every wrapper between it and the base env. It now calls
  self.env.step() like the other wrappers in this patch.
- cut_tiles, cut_coords, seen_hidden_objs and seen_signs were refreshed from
  the keys of seen_npcs (copy-paste). Each now iterates its own mapping.
  NOTE(review): assumes all four are dicts, as their .update() usage suggests.

diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py index 4438cfd..2157a39 100644 --- a/pokemonred_puffer/cleanrl_puffer.py +++ b/pokemonred_puffer/cleanrl_puffer.py @@ -5,6 +5,7 @@ import os import pathlib import random +import sqlite3 import time from collections import defaultdict, deque from dataclasses import dataclass, field @@ -204,6 +205,10 @@ def __post_init__(self): self.archive_path.mkdir(exist_ok=False) print(f"Will archive states to {self.archive_path}") + self.conn = sqlite3.connect("states.db") + self.cur = self.conn.cursor() + self.cur.execute("CREATE TABLE IF NOT EXISTS states(env_id INT PRIMARY KEY, state TEXT)") + + @pufferlib.utils.profile def evaluate(self): # states are managed separately so dont worry about deleting them @@ -338,17 +343,19 @@ def evaluate(self): # Need a way not to reset the env id counter for the driver env # Until then env ids are 1-indexed print(f"\tNew events ({len(new_state_key)}): {new_state_key}") - for key in self.event_tracker.keys(): - new_state = random.choice(self.states[new_state_key]) - - self.env_recv_queues[key].put(new_state) - # Now copy the hidden state over - # This may be a little slow, but so is this whole process - # self.next_lstm_state[0][:, i, :] = self.next_lstm_state[0][:, new_state, :] - # self.next_lstm_state[1][:, i, :] = self.next_lstm_state[1][:, new_state, :] - for key in self.event_tracker.keys(): - # print(f"\tWaiting for message from env-id {key}") - self.env_send_queues[key].get() + env_ids = list(self.event_tracker.keys()) + new_states = random.choices(self.states[new_state_key], k=len(env_ids)) + # Bind values as parameters: states are not SQL-safe text, and each + # row needs its env_id for the ON CONFLICT(env_id) upsert to apply. + self.cur.executemany( + "INSERT INTO states(env_id, state) VALUES (?, ?) " + "ON CONFLICT(env_id) " + "DO UPDATE SET state=EXCLUDED.state;", + zip(env_ids, new_states), + ) + # Commit so other connections to states.db can read the rows. + self.conn.commit() + self.vecenv.async_reset() print( f"State migration to {self.archive_path}/{str(hash(new_state_key))} complete" ) diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 
579a880..1ff8cc4 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -5,6 +5,7 @@ import random from collections import deque from pathlib import Path +import sqlite3 from typing import Any, Iterable, Optional import uuid @@ -172,6 +173,8 @@ def __init__(self, env_config: DictConfig): self.map_frame_writer = None self.reset_count = 0 self.all_runs = [] + self.conn = sqlite3.connect("states.db") + self.cur = self.conn.cursor() # Set this in SOME subclasses self.metadata = {"render.modes": []} @@ -280,6 +283,7 @@ def register_hooks(self): self.sign_hook, None, ) + self.reset_count = 0 def setup_disable_wild_encounters(self): bank, addr = self.pyboy.symbol_lookup("TryDoWildEncounter.gotWildEncounterType") @@ -303,33 +307,18 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = infos = {} self.explore_map_dim = 384 + # res = self.cur.execute(f"SELECT state FROM states WHERE env_id={self.env_id}") if self.first or options.get("state", None) is not None: - self.recent_screens = deque() - self.recent_actions = deque() # We only init seen hidden objs once cause they can only be found once! - self.a_press = set() if options.get("state", None) is not None: self.pyboy.load_state(io.BytesIO(options["state"])) - self.reset_count += 1 else: with open(self.init_state_path, "rb") as f: self.pyboy.load_state(f) - self.reset_count = 0 self.base_event_flags = sum( self.read_m(i).bit_count() for i in range(EVENT_FLAGS_START, EVENT_FLAGS_START + EVENTS_FLAGS_LENGTH) ) - # A bit of duplicate code. Blah. 
- self.events = EventFlags(self.pyboy) - self.missables = MissableFlags(self.pyboy) - self.flags = Flags(self.pyboy) - self.party = PartyMons(self.pyboy) - self.required_events = self.get_required_events() - self.required_items = self.get_required_items() - self.seen_pokemon = np.zeros(152, dtype=np.uint8) - self.caught_pokemon = np.zeros(152, dtype=np.uint8) - self.moves_obtained = np.zeros(0xA5, dtype=np.uint8) - self.pokecenters = np.zeros(252, dtype=np.uint8) if self.save_state: state = io.BytesIO() @@ -348,24 +337,27 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = # if not seed: # seed = random.randint(0, 4096) # self.pyboy.tick(seed, render=False) - else: - self.reset_count += 1 - - self.recent_screens.clear() - self.recent_actions.clear() - self.a_press.clear() - self.seen_pokemon.fill(0) - self.caught_pokemon.fill(0) - self.moves_obtained.fill(0) - self.explore_map *= 0 - self.reward_explore_map *= 0 - self.cut_explore_map *= 0 - self.reset_mem() + self.reset_count += 1 self.events = EventFlags(self.pyboy) self.missables = MissableFlags(self.pyboy) self.flags = Flags(self.pyboy) self.party = PartyMons(self.pyboy) + self.required_events = self.get_required_events() + self.required_items = self.get_required_items() + self.seen_pokemon = np.zeros(152, dtype=np.uint8) + self.caught_pokemon = np.zeros(152, dtype=np.uint8) + self.moves_obtained = np.zeros(0xA5, dtype=np.uint8) + self.pokecenters = np.zeros(252, dtype=np.uint8) + + self.recent_screens = deque() + self.recent_actions = deque() + self.a_press = set() + self.explore_map *= 0 + self.reward_explore_map *= 0 + self.cut_explore_map *= 0 + self.reset_mem() + self.update_pokedex() self.update_tm_hm_moves_obtained() self.party_size = self.read_m("wPartyCount") @@ -375,8 +367,6 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = self.levels_satisfied = False self.base_explore = 0 self.max_opponent_level = 0 - self.required_events = 
self.get_required_events() - self.required_items = self.get_required_items() self.max_level_rew = 0 self.max_level_sum = 0 self.last_health = 1 @@ -741,11 +731,6 @@ def step(self, action): obs = self._get_obs() self.step_count += 1 - reset = ( - self.step_count >= self.get_max_steps() - # or - # self.caught_pokemon[6] == 1 # squirtle - ) # cut mon check if not self.party_has_cut_capable_mon(): diff --git a/pokemonred_puffer/wrappers/exploration.py b/pokemonred_puffer/wrappers/exploration.py index 8824e2e..86cdcf4 100644 --- a/pokemonred_puffer/wrappers/exploration.py +++ b/pokemonred_puffer/wrappers/exploration.py @@ -47,9 +47,6 @@ def step(self, action): return self.env.step(action) - def reset(self, *args, **kwargs): - return self.env.reset(*args, **kwargs) - def step_forget_explore(self): self.env.unwrapped.seen_coords.update( (k, max(0.15, v * (self.step_forgetting_factor["coords"]))) @@ -94,6 +91,9 @@ def __init__(self, env: RedGymEnv, reward_config: DictConfig): self.cache = LRUCache(capacity=self.capacity) def step(self, action): + if self.env.unwrapped.step_count >= self.env.unwrapped.get_max_steps(): + self.cache.clear() + step = self.env.step(action) player_x, player_y, map_n = self.env.unwrapped.get_game_coords() # Walrus operator does not support tuple unpacking @@ -103,10 +103,6 @@ def step(self, action): self.env.unwrapped.explore_map[local_to_global(y, x, n)] = 0 return step - def reset(self, *args, **kwargs): - self.cache.clear() - return self.env.reset(*args, **kwargs) - class OnResetExplorationWrapper(gym.Wrapper): def __init__(self, env: RedGymEnv, reward_config: DictConfig): @@ -115,21 +111,22 @@ def __init__(self, env: RedGymEnv, reward_config: DictConfig): self.jitter = reward_config.jitter self.counter = 0 - def reset(self, *args, **kwargs): - if (self.counter + random.randint(0, self.jitter)) >= self.full_reset_frequency: - self.counter = 0 - self.env.unwrapped.explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32) - 
self.env.unwrapped.cut_explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32) - self.env.unwrapped.seen_coords.clear() - self.env.unwrapped.seen_map_ids *= 0 - self.env.unwrapped.seen_npcs.clear() - self.env.unwrapped.cut_coords.clear() - self.env.unwrapped.cut_tiles.clear() - self.env.unwrapped.seen_warps.clear() - self.env.unwrapped.seen_hidden_objs.clear() - self.env.unwrapped.seen_signs.clear() - self.counter += 1 - return self.env.reset(*args, **kwargs) + def step(self, action): + if self.env.unwrapped.step_count >= self.env.unwrapped.get_max_steps(): + if (self.counter + random.randint(0, self.jitter)) >= self.full_reset_frequency: + self.counter = 0 + self.env.unwrapped.explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32) + self.env.unwrapped.cut_explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32) + self.env.unwrapped.seen_coords.clear() + self.env.unwrapped.seen_map_ids *= 0 + self.env.unwrapped.seen_npcs.clear() + self.env.unwrapped.cut_coords.clear() + self.env.unwrapped.cut_tiles.clear() + self.env.unwrapped.seen_warps.clear() + self.env.unwrapped.seen_hidden_objs.clear() + self.env.unwrapped.seen_signs.clear() + self.counter += 1 + return self.env.step(action) class OnResetLowerToFixedValueWrapper(gym.Wrapper): @@ -137,41 +134,50 @@ def __init__(self, env: RedGymEnv, reward_config: DictConfig): super().__init__(env) self.fixed_value = reward_config.fixed_value - def reset(self, *args, **kwargs): - self.env.unwrapped.seen_coords.update( - (k, self.fixed_value["coords"]) - for k, v in self.env.unwrapped.seen_coords.items() - if v > 0 - ) - self.env.unwrapped.seen_map_ids[self.env.unwrapped.seen_map_ids > 0] = self.fixed_value[ - "map_ids" - ] - self.env.unwrapped.seen_npcs.update( - (k, self.fixed_value["npc"]) for k, v in self.env.unwrapped.seen_npcs.items() if v > 0 - ) - self.env.unwrapped.cut_tiles.update( - (k, self.fixed_value["cut"]) for k, v in self.env.unwrapped.seen_npcs.items() if v > 0 - ) - 
self.env.unwrapped.cut_coords.update( - (k, self.fixed_value["cut"]) for k, v in self.env.unwrapped.seen_npcs.items() if v > 0 - ) - self.env.unwrapped.explore_map[self.env.unwrapped.explore_map > 0] = self.fixed_value[ - "explore" - ] - self.env.unwrapped.cut_explore_map[self.env.unwrapped.cut_explore_map > 0] = ( - self.fixed_value["cut"] - ) - self.env.unwrapped.seen_warps.update( - (k, self.fixed_value["coords"]) - for k, v in self.env.unwrapped.seen_warps.items() - if v > 0 - ) - self.env.unwrapped.seen_hidden_objs.update( - (k, self.fixed_value["hidden_objs"]) - for k, v in self.env.unwrapped.seen_npcs.items() - if v > 0 - ) - self.env.unwrapped.seen_signs.update( - (k, self.fixed_value["signs"]) for k, v in self.env.unwrapped.seen_npcs.items() if v > 0 - ) - return self.env.reset(*args, **kwargs) + def step(self, action): + if self.env.unwrapped.step_count >= self.env.unwrapped.get_max_steps(): + self.env.unwrapped.seen_coords.update( + (k, self.fixed_value["coords"]) + for k, v in self.env.unwrapped.seen_coords.items() + if v > 0 + ) + self.env.unwrapped.seen_map_ids[self.env.unwrapped.seen_map_ids > 0] = self.fixed_value[ + "map_ids" + ] + self.env.unwrapped.seen_npcs.update( + (k, self.fixed_value["npc"]) + for k, v in self.env.unwrapped.seen_npcs.items() + if v > 0 + ) + self.env.unwrapped.cut_tiles.update( + (k, self.fixed_value["cut"]) + for k, v in self.env.unwrapped.cut_tiles.items() + if v > 0 + ) + self.env.unwrapped.cut_coords.update( + (k, self.fixed_value["cut"]) + for k, v in self.env.unwrapped.cut_coords.items() + if v > 0 + ) + self.env.unwrapped.explore_map[self.env.unwrapped.explore_map > 0] = self.fixed_value[ + "explore" + ] + self.env.unwrapped.cut_explore_map[self.env.unwrapped.cut_explore_map > 0] = ( + self.fixed_value["cut"] + ) + self.env.unwrapped.seen_warps.update( + (k, self.fixed_value["coords"]) + for k, v in self.env.unwrapped.seen_warps.items() + if v > 0 + ) + self.env.unwrapped.seen_hidden_objs.update( + (k, 
self.fixed_value["hidden_objs"]) + for k, v in self.env.unwrapped.seen_hidden_objs.items() + if v > 0 + ) + self.env.unwrapped.seen_signs.update( + (k, self.fixed_value["signs"]) + for k, v in self.env.unwrapped.seen_signs.items() + if v > 0 + ) + return self.env.step(action)