From 3cc3b2ee7d07e5f2d9cc30c5affaa9523dc31bc5 Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Wed, 26 Jun 2024 09:35:02 -0400 Subject: [PATCH] Custom episode stats wrapper to minimize data passed through infos --- config.yaml | 3 +- pokemonred_puffer/environment.py | 33 ++++++------------ pokemonred_puffer/eval.py | 2 ++ pokemonred_puffer/train.py | 1 - pokemonred_puffer/wrappers/episode_stats.py | 36 ++++++++++++++++++++ pokemonred_puffer/wrappers/stream_wrapper.py | 20 +++++------ pyproject.toml | 1 + 7 files changed, 62 insertions(+), 34 deletions(-) create mode 100644 pokemonred_puffer/wrappers/episode_stats.py diff --git a/config.yaml b/config.yaml index 1b8cf16..8931420 100644 --- a/config.yaml +++ b/config.yaml @@ -22,7 +22,7 @@ debug: env_batch_size: 4 env_pool: True zero_copy: False - batch_size: 4 + batch_size: 128 minibatch_size: 4 batch_rows: 4 bptt_horizon: 2 @@ -149,6 +149,7 @@ wrappers: jitter: 0 stream_only: + - episode_stats.EpisodeStatsWrapper: {} - stream_wrapper.StreamWrapper: user: thatguy - exploration.OnResetExplorationWrapper: diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 530ed50..94375f7 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -754,45 +754,37 @@ def cut_if_next(self): # Gym trees apparently get the same tile map as outside bushes # GYM = 7 if (in_overworld and 0x3D in up) or (in_erika_gym and 0x50 in up): - self.pyboy.send_input(WindowEvent.PRESS_ARROW_UP) - self.pyboy.send_input(WindowEvent.RELEASE_ARROW_UP, delay=8) + self.pyboy.button("UP", delay=8) self.pyboy.tick(self.action_freq, render=True) elif (in_overworld and 0x3D in down) or (in_erika_gym and 0x50 in down): - self.pyboy.send_input(WindowEvent.PRESS_ARROW_DOWN) - self.pyboy.send_input(WindowEvent.RELEASE_ARROW_DOWN, delay=8) + self.pyboy.button("DOWN", delay=8) self.pyboy.tick(self.action_freq, render=True) elif (in_overworld and 0x3D in
left) or (in_erika_gym and 0x50 in left): - self.pyboy.send_input(WindowEvent.PRESS_ARROW_LEFT) - self.pyboy.send_input(WindowEvent.RELEASE_ARROW_LEFT, delay=8) + self.pyboy.button("LEFT", delay=8) self.pyboy.tick(self.action_freq, render=True) elif (in_overworld and 0x3D in right) or (in_erika_gym and 0x50 in right): - self.pyboy.send_input(WindowEvent.PRESS_ARROW_RIGHT) - self.pyboy.send_input(WindowEvent.RELEASE_ARROW_RIGHT, delay=8) + self.pyboy.button("RIGHT", delay=8) self.pyboy.tick(self.action_freq, render=True) else: return # open start menu - self.pyboy.send_input(WindowEvent.PRESS_BUTTON_START) - self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_START, delay=8) + self.pyboy.button("START", delay=8) self.pyboy.tick(self.action_freq, render=True) # scroll to pokemon # 1 is the item index for pokemon for _ in range(24): if self.pyboy.memory[self.pyboy.symbol_lookup("wCurrentMenuItem")[1]] == 1: break - self.pyboy.send_input(WindowEvent.PRESS_ARROW_DOWN) - self.pyboy.send_input(WindowEvent.RELEASE_ARROW_DOWN, delay=8) + self.pyboy.button("DOWN", delay=8) self.pyboy.tick(self.action_freq, render=True) - self.pyboy.send_input(WindowEvent.PRESS_BUTTON_A) - self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_A, delay=8) + self.pyboy.button("A", delay=8) self.pyboy.tick(self.action_freq, render=True) # find pokemon with cut # We run this over all pokemon so we dont end up in an infinite for loop for _ in range(7): - self.pyboy.send_input(WindowEvent.PRESS_ARROW_DOWN) - self.pyboy.send_input(WindowEvent.RELEASE_ARROW_DOWN, delay=8) + self.pyboy.button("DOWN", delay=8) self.pyboy.tick(self.action_freq, render=True) party_mon = self.pyboy.memory[self.pyboy.symbol_lookup("wCurrentMenuItem")[1]] _, addr = self.pyboy.symbol_lookup(f"wPartyMon{party_mon%6+1}Moves") @@ -800,8 +792,7 @@ def cut_if_next(self): break # Enter submenu - self.pyboy.send_input(WindowEvent.PRESS_BUTTON_A) - self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_A, delay=8) +
self.pyboy.button("A", delay=8) self.pyboy.tick(4 * self.action_freq, render=True) # Scroll until the field move is found @@ -812,14 +803,12 @@ def cut_if_next(self): current_item = self.read_m("wCurrentMenuItem") if current_item < 4 and FieldMoves.CUT.value == field_moves[current_item]: break - self.pyboy.send_input(WindowEvent.PRESS_ARROW_DOWN) - self.pyboy.send_input(WindowEvent.RELEASE_ARROW_DOWN, delay=8) + self.pyboy.button("DOWN", delay=8) self.pyboy.tick(self.action_freq, render=True) # press a bunch of times for _ in range(5): - self.pyboy.send_input(WindowEvent.PRESS_BUTTON_A) - self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_A, delay=8) + self.pyboy.button("A", delay=8) self.pyboy.tick(4 * self.action_freq, render=True) def surf_if_attempt(self, action: WindowEvent): diff --git a/pokemonred_puffer/eval.py b/pokemonred_puffer/eval.py index f61fb1b..b7c3536 100644 --- a/pokemonred_puffer/eval.py +++ b/pokemonred_puffer/eval.py @@ -3,11 +3,13 @@ import cv2 import matplotlib.colors as mcolors import numpy as np +from numba import jit KANTO_MAP_PATH = os.path.join(os.path.dirname(__file__), "kanto_map_dsv.png") BACKGROUND = np.array(cv2.imread(KANTO_MAP_PATH)) +@jit(nopython=True, nogil=True, parallel=True) def make_pokemon_red_overlay(counts: np.ndarray): # TODO: Rethink how this scaling works # Divide by number of elements > 0 diff --git a/pokemonred_puffer/train.py b/pokemonred_puffer/train.py index 02df802..420e2f2 100644 --- a/pokemonred_puffer/train.py +++ b/pokemonred_puffer/train.py @@ -93,7 +93,6 @@ def env_creator( env = wrapper_class(env, pufferlib.namespace(**[x for x in cfg.values()][0])) if async_wrapper and async_config: env = AsyncWrapper(env, async_config["send_queues"], async_config["recv_queues"]) - # env = pufferlib.postprocess.EpisodeStats(env) return pufferlib.emulation.GymnasiumPufferEnv(env=env) return env_creator diff --git a/pokemonred_puffer/wrappers/episode_stats.py b/pokemonred_puffer/wrappers/episode_stats.py new file mode
100644 index 0000000..7249e6a --- /dev/null +++ b/pokemonred_puffer/wrappers/episode_stats.py @@ -0,0 +1,36 @@
+import numpy as np
+import gymnasium
+
+import pufferlib.utils
+
+
+class EpisodeStatsWrapper(gymnasium.Wrapper):
+    """Accumulate episode stats; emit them at episode end or every env.log_frequency steps."""
+
+    def __init__(self, env, *args, **kwargs):
+        # Extra positional/keyword arguments (e.g. a wrapper config) are accepted and ignored.
+        super().__init__(env)
+        self.reset()
+
+    def reset(self, seed=None, options=None):
+        """Clear the accumulated stats, then reset the wrapped env with seed and options."""
+        self.info = dict(episode_return=0, episode_length=0)
+        return super().reset(seed=seed, options=options)
+
+    def step(self, action):
+        """Step the env, fold its info into the running stats, and return (obs, reward,
+        terminated, truncated, info) where info is empty except on reporting steps."""
+        observation, reward, terminated, truncated, info = super().step(action)
+        for k, v in pufferlib.utils.unroll_nested_dict(info):
+            if "exploration_map" in k:
+                # Exploration maps are summed over steps; other keys keep the latest value.
+                self.info[k] = self.info.get(k, np.zeros_like(v)) + v
+            else:
+                self.info[k] = v
+        self.info["episode_return"] += reward
+        self.info["episode_length"] += 1
+
+        if terminated or truncated or self.info["episode_length"] % self.env.log_frequency == 0:
+            # Hand out a copy so later steps cannot mutate what the caller already received.
+            return observation, reward, terminated, truncated, dict(self.info)
+        return observation, reward, terminated, truncated, {}
diff --git a/pokemonred_puffer/wrappers/stream_wrapper.py b/pokemonred_puffer/wrappers/stream_wrapper.py index 9fc2d8c..4440d36 100644 --- a/pokemonred_puffer/wrappers/stream_wrapper.py +++ b/pokemonred_puffer/wrappers/stream_wrapper.py @@ -15,25 +15,25 @@ class StreamWrapper(gym.Wrapper): def __init__(self, env: RedGymEnv, config: pufferlib.namespace): super().__init__(env) - with RedGymEnv.lock: + with StreamWrapper.lock: env_id = ( - (int(RedGymEnv.env_id.buf[0]) << 24) - + (int(RedGymEnv.env_id.buf[1]) << 16) - + (int(RedGymEnv.env_id.buf[2]) << 8) - + (int(RedGymEnv.env_id.buf[3])) + (int(StreamWrapper.env_id.buf[0]) << 24) + + (int(StreamWrapper.env_id.buf[1]) << 16) + + (int(StreamWrapper.env_id.buf[2]) << 8) + + (int(StreamWrapper.env_id.buf[3])) ) self.env_id = env_id env_id += 1 -
RedGymEnv.env_id.buf[0] = (env_id >> 24) & 0xFF - RedGymEnv.env_id.buf[1] = (env_id >> 16) & 0xFF - RedGymEnv.env_id.buf[2] = (env_id >> 8) & 0xFF - RedGymEnv.env_id.buf[3] = (env_id) & 0xFF + StreamWrapper.env_id.buf[0] = (env_id >> 24) & 0xFF + StreamWrapper.env_id.buf[1] = (env_id >> 16) & 0xFF + StreamWrapper.env_id.buf[2] = (env_id >> 8) & 0xFF + StreamWrapper.env_id.buf[3] = (env_id) & 0xFF self.user = config.user self.ws_address = "wss://transdimensional.xyz/broadcast" self.stream_metadata = { "user": self.user, - "env_id": env.env_id, + "env_id": self.env_id, } self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) diff --git a/pyproject.toml b/pyproject.toml index f620214..fc3f09c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ classifiers = [ dependencies = [ "einops", "mediapy", + "numba", "numpy", "opencv-python", "pyboy>=2",