From e7688f2c63e8493e785e122226251be5d3e9625f Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Tue, 2 Jul 2024 23:11:10 -0400 Subject: [PATCH] something to test on a pufferbox --- config.yaml | 13 +-- pokemonred_puffer/cleanrl_puffer.py | 98 +++++++++++--------- pokemonred_puffer/environment.py | 47 +++++++++- pokemonred_puffer/wrappers/episode_stats.py | 3 + pokemonred_puffer/wrappers/stream_wrapper.py | 17 ---- 5 files changed, 105 insertions(+), 73 deletions(-) diff --git a/config.yaml b/config.yaml index 36a056f..99842f8 100644 --- a/config.yaml +++ b/config.yaml @@ -19,9 +19,10 @@ debug: device: cpu compile: False compile_mode: default - num_envs: 1 + num_envs: 16 + envs_per_worker: 1 num_workers: 1 - env_batch_size: 4 + env_batch_size: 32 env_pool: True zero_copy: False batch_size: 128 @@ -36,8 +37,8 @@ debug: verbose: False env_pool: False load_optimizer_state: False - # swarm_frequency: 10 - # swarm_keep_pct: .1 + swarm_frequency: 1 + swarm_keep_pct: .1 env: headless: True @@ -68,7 +69,7 @@ env: auto_pokeflute: True infinite_money: True use_global_map: False - save_state: False + save_state: True animate_scripts: False @@ -116,7 +117,7 @@ train: pool_kernel: [0] load_optimizer_state: False use_rnn: True - async_wrapper: False + async_wrapper: True # swarm_frequency: 500 # swarm_keep_pct: .8 diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py index 27143aa..76494c4 100644 --- a/pokemonred_puffer/cleanrl_puffer.py +++ b/pokemonred_puffer/cleanrl_puffer.py @@ -1,4 +1,5 @@ import argparse +from functools import partial import heapq import math import os @@ -138,6 +139,7 @@ class CleanPuffeRL: stats: dict = field(default_factory=lambda: {}) msg: str = "" infos: dict = field(default_factory=lambda: defaultdict(list)) + states: dict = field(default_factory=lambda: defaultdict(partial(deque, maxlen=5))) def __post_init__(self): seed_everything(self.config.seed, self.config.torch_deterministic) @@ -199,51 +201,9 @@ def __post_init__(self): @pufferlib.utils.profile def evaluate(self): - # Clear all self.infos except for the state + # states are managed separately so dont worry about deleting them for k in list(self.infos.keys()): - if k != "state": - del self.infos[k] - - # now for a tricky bit: - # if we have swarm_frequency, we will take the top swarm_keep_pct envs and evenly distribute - # their states to the bottom 90%. - # we do this here so the environment can remain "pure" - if ( - self.config.async_wrapper - and hasattr(self.config, "swarm_frequency") - and hasattr(self.config, "swarm_keep_pct") - and self.epoch % self.config.swarm_frequency == 0 - and "reward/event" in self.infos - and "state" in self.infos - ): - # collect the top swarm_keep_pct % of envs - largest = [ - x[0] - for x in heapq.nlargest( - math.ceil(self.config.num_envs * self.config.swarm_keep_pct), - enumerate(self.infos["reward/event"]), - key=lambda x: x[1], - ) - ] - print("Migrating states:") - waiting_for = [] - # Need a way not to reset the env id counter for the driver env - # Until then env ids are 1-indexed - for i in range(self.config.num_envs): - if i not in largest: - new_state = random.choice(largest) - print( - f'\t {i+1} -> {new_state+1}, event scores: {self.infos["reward/event"][i]} -> {self.infos["reward/event"][new_state]}' - ) - self.env_recv_queues[i + 1].put(self.infos["state"][new_state]) - waiting_for.append(i + 1) - # Now copy the hidden state over - # This may be a little slow, but so is this whole process - self.next_lstm_state[0][:, i, :] = self.next_lstm_state[0][:, new_state, :] - self.next_lstm_state[1][:, i, :] = self.next_lstm_state[1][:, new_state, :] - for i in waiting_for: - self.env_send_queues[i].get() - print("State migration complete") + del self.infos[k] with self.profile.eval_misc: policy = self.policy @@ -289,8 +249,9 @@ def evaluate(self): for i in info: for k, v in pufferlib.utils.unroll_nested_dict(i): - if k == "state": - self.infos[k] = [v] + if "state/" in k: + _, key = k.split("/") + self.states[key].append(v) else: self.infos[k].append(v) @@ -298,6 +259,51 @@ def evaluate(self): self.vecenv.send(actions) with self.profile.eval_misc: + # now for a tricky bit: + # if we have swarm_frequency, we will migrate the bottom + # % of envs in the batch (by required events count) + # and migrate them to a new state at random. + # Now this has a lot of gotchas and is really unstable + # E.g. Some envs could just constantly be on the bottom since they're never + # progressing + breakpoint() + if ( + self.config.async_wrapper + and hasattr(self.config, "swarm_frequency") + and hasattr(self.config, "swarm_keep_pct") + and self.epoch % self.config.swarm_frequency == 0 + and "required_events_count" in self.infos + and self.states + ): + # collect the top swarm_keep_pct % of the envs in the batch + largest = [ + x[0] + for x in heapq.nlargest( + math.ceil(self.config.num_envs * self.config.swarm_keep_pct), + enumerate(self.infos["required_events_count"]), + key=lambda x: x[1], + ) + ] + print("Migrating states:") + waiting_for = [] + # Need a way not to reset the env id counter for the driver env + # Until then env ids are 1-indexed + for i in range(self.config.num_envs): + if i not in largest: + new_state = random.choice(largest) + print( + f'\t {i+1} -> {new_state+1}, event scores: {self.infos["reward/event"][i]} -> {self.infos["reward/event"][new_state]}' + ) + self.env_recv_queues[i + 1].put(self.infos["state"][new_state]) + waiting_for.append(i + 1) + # Now copy the hidden state over + # This may be a little slow, but so is this whole process + self.next_lstm_state[0][:, i, :] = self.next_lstm_state[0][:, new_state, :] + self.next_lstm_state[1][:, i, :] = self.next_lstm_state[1][:, new_state, :] + for i in waiting_for: + self.env_send_queues[i].get() + print("State migration complete") + self.stats = {} for k, v in self.infos.items(): diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index d153518..5cf0546 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -1,5 +1,6 @@ from abc import abstractmethod import io +from multiprocessing import Lock, shared_memory import os import random from collections import deque @@ -78,6 +79,9 @@ # TODO: Make global map usage a configuration parameter class RedGymEnv(Env): + env_id = shared_memory.SharedMemory(create=True, size=4) + lock = Lock() + def __init__(self, env_config: pufferlib.namespace): # TODO: Dont use pufferlib.namespace. It seems to confuse __init__ self.video_dir = Path(env_config.video_dir) @@ -209,6 +213,19 @@ def __init__(self, env_config: pufferlib.namespace): self.first = True + with RedGymEnv.lock: + env_id = ( + (int(RedGymEnv.env_id.buf[0]) << 24) + + (int(RedGymEnv.env_id.buf[1]) << 16) + + (int(RedGymEnv.env_id.buf[2]) << 8) + + (int(RedGymEnv.env_id.buf[3])) + ) + self.env_id = env_id + env_id += 1 + RedGymEnv.env_id.buf[0] = (env_id >> 24) & 0xFF + RedGymEnv.env_id.buf[1] = (env_id >> 16) & 0xFF + RedGymEnv.env_id.buf[2] = (env_id >> 8) & 0xFF + RedGymEnv.env_id.buf[3] = (env_id) & 0xFF self.init_mem() def register_hooks(self): @@ -256,6 +273,7 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = # restart game, skipping credits options = options or {} + infos = {} self.explore_map_dim = 384 if self.first or options.get("state", None) is not None: self.recent_screens = deque() @@ -305,6 +323,7 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = self.base_explore = 0 self.max_opponent_level = 0 self.max_event_rew = 0 + self.required_events = self.get_required_events() self.max_level_rew = 0 self.max_level_sum = 0 self.last_health = 1 @@ -324,12 +343,15 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = self.total_reward = sum([val for _, val in self.progress_reward.items()]) self.first = False - infos = {} + if self.save_state: state = io.BytesIO() self.pyboy.save_state(state) state.seek(0) - infos |= {"state": state.read()} + infos |= { + "state": {hash("".join(self.required_events)): state.read()}, + "required_events_count": len(self.required_events), + } return self._get_obs(), infos def init_mem(self): @@ -611,11 +633,16 @@ def step(self, action): self.pokecenters[self.read_m("wLastBlackoutMap")] = 1 info = {} - if self.save_state and self.get_events_sum() > self.max_event_rew: + required_events = self.get_required_events() + new_required_events = required_events - self.required_events + if self.save_state and new_required_events: + breakpoint() state = io.BytesIO() self.pyboy.save_state(state) state.seek(0) - info["state"] = state.read() + info["state"] = {hash(required_events): state.read()} + info["required_events_count"] = len(required_events) + self.required_events = required_events # TODO: Make log frequency a configuration parameter if self.step_count % self.log_frequency == 0: @@ -1413,6 +1440,18 @@ def get_levels_reward(self): level_reward = 30 + (self.max_level_sum - 30) / 4 return level_reward + def get_required_events(self) -> set[str]: + return ( + {event for event in REQUIRED_EVENTS if self.events.get_event(event)} + | ({"rival3"} if (self.read_m("wSSAnne2FCurScript") == 4) else {}) + | ( + {"game_corner_rocket"} + if self.missables.get_missable("HS_GAME_CORNER_ROCKET") + else {} + ) + | ({"saffron_guard"} if self.wd728.get_bit("GAVE_SAFFRON_GUARD_DRINK") else {}) + ) + def get_events_sum(self): # adds up all event flags, exclude museum ticket return max( diff --git a/pokemonred_puffer/wrappers/episode_stats.py b/pokemonred_puffer/wrappers/episode_stats.py index 7249e6a..a0e5143 100644 --- a/pokemonred_puffer/wrappers/episode_stats.py +++ b/pokemonred_puffer/wrappers/episode_stats.py @@ -22,6 +22,9 @@ def step(self, action): for k, v in pufferlib.utils.unroll_nested_dict(info): if "exploration_map" in k: self.info[k] = self.info.get(k, np.zeros_like(v)) + v + elif "state" in k: + breakpoint() + self.info["state"] |= v else: self.info[k] = v diff --git a/pokemonred_puffer/wrappers/stream_wrapper.py b/pokemonred_puffer/wrappers/stream_wrapper.py index 4440d36..ccbae3b 100644 --- a/pokemonred_puffer/wrappers/stream_wrapper.py +++ b/pokemonred_puffer/wrappers/stream_wrapper.py @@ -1,6 +1,5 @@ import asyncio import json -from multiprocessing import Lock, shared_memory import gymnasium as gym import websockets @@ -10,24 +9,8 @@ class StreamWrapper(gym.Wrapper): - env_id = shared_memory.SharedMemory(create=True, size=4) - lock = Lock() - def __init__(self, env: RedGymEnv, config: pufferlib.namespace): super().__init__(env) - with StreamWrapper.lock: - env_id = ( - (int(StreamWrapper.env_id.buf[0]) << 24) - + (int(StreamWrapper.env_id.buf[1]) << 16) - + (int(StreamWrapper.env_id.buf[2]) << 8) - + (int(StreamWrapper.env_id.buf[3])) - ) - self.env_id = env_id - env_id += 1 - StreamWrapper.env_id.buf[0] = (env_id >> 24) & 0xFF - StreamWrapper.env_id.buf[1] = (env_id >> 16) & 0xFF - StreamWrapper.env_id.buf[2] = (env_id >> 8) & 0xFF - StreamWrapper.env_id.buf[3] = (env_id) & 0xFF self.user = config.user self.ws_address = "wss://transdimensional.xyz/broadcast"