Skip to content

Commit

Permalink
Custom episode stats wrapper to minimize data passed through infos
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Jun 26, 2024
1 parent 8eb0d5c commit 3cc3b2e
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 34 deletions.
3 changes: 2 additions & 1 deletion config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ debug:
env_batch_size: 4
env_pool: True
zero_copy: False
batch_size: 4
batch_size: 128
minibatch_size: 4
batch_rows: 4
bptt_horizon: 2
Expand Down Expand Up @@ -149,6 +149,7 @@ wrappers:
jitter: 0

stream_only:
- episode_stats.EpisodeStatsWrapper: {}
- stream_wrapper.StreamWrapper:
user: thatguy
- exploration.OnResetExplorationWrapper:
Expand Down
34 changes: 12 additions & 22 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,54 +754,46 @@ def cut_if_next(self):
# Gym trees apparently get the same tile map as outside bushes
# GYM = 7
if (in_overworld and 0x3D in up) or (in_erika_gym and 0x50 in up):
self.pyboy.send_input(WindowEvent.PRESS_ARROW_UP)
self.pyboy.send_input(WindowEvent.RELEASE_ARROW_UP, delay=8)
self.pyboy.button("UP", delay=8)
self.pyboy.tick(self.action_freq, render=True)
elif (in_overworld and 0x3D in down) or (in_erika_gym and 0x50 in down):
self.pyboy.send_input(WindowEvent.PRESS_ARROW_DOWN)
self.pyboy.send_input(WindowEvent.RELEASE_ARROW_DOWN, delay=8)
self.pyboy.button("DOWN", delay=8)
self.pyboy.tick(self.action_freq, render=True)
elif (in_overworld and 0x3D in left) or (in_erika_gym and 0x50 in left):
self.pyboy.send_input(WindowEvent.PRESS_ARROW_LEFT)
self.pyboy.send_input(WindowEvent.RELEASE_ARROW_LEFT, delay=8)
self.pyboy.button("LEFT", delay=8)
self.pyboy.tick(self.action_freq, render=True)
elif (in_overworld and 0x3D in right) or (in_erika_gym and 0x50 in right):
self.pyboy.send_input(WindowEvent.PRESS_ARROW_RIGHT)
self.pyboy.send_input(WindowEvent.RELEASE_ARROW_RIGHT, delay=8)
self.pyboy.button("RIGHT", delay=8)
self.pyboy.tick(self.action_freq, render=True)
else:
return
breakpoint()

# open start menu
self.pyboy.send_input(WindowEvent.PRESS_BUTTON_START)
self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_START, delay=8)
self.pyboy.button("START", delay=8)
self.pyboy.tick(self.action_freq, render=True)
# scroll to pokemon
# 1 is the item index for pokemon
for _ in range(24):
if self.pyboy.memory[self.pyboy.symbol_lookup("wCurrentMenuItem")[1]] == 1:
break
self.pyboy.send_input(WindowEvent.PRESS_ARROW_DOWN)
self.pyboy.send_input(WindowEvent.RELEASE_ARROW_DOWN, delay=8)
self.pyboy.button("DOWN", delay=8)
self.pyboy.tick(self.action_freq, render=True)
self.pyboy.send_input(WindowEvent.PRESS_BUTTON_A)
self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_A, delay=8)
self.pyboy.button("A", delay=8)
self.pyboy.tick(self.action_freq, render=True)

# find pokemon with cut
# We run this over all pokemon so we dont end up in an infinite for loop
for _ in range(7):
self.pyboy.send_input(WindowEvent.PRESS_ARROW_DOWN)
self.pyboy.send_input(WindowEvent.RELEASE_ARROW_DOWN, delay=8)
self.pyboy.button("DOWN", delay=8)
self.pyboy.tick(self.action_freq, render=True)
party_mon = self.pyboy.memory[self.pyboy.symbol_lookup("wCurrentMenuItem")[1]]
_, addr = self.pyboy.symbol_lookup(f"wPartyMon{party_mon%6+1}Moves")
if 0xF in self.pyboy.memory[addr : addr + 4]:
break

# Enter submenu
self.pyboy.send_input(WindowEvent.PRESS_BUTTON_A)
self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_A, delay=8)
self.pyboy.button("A", delay=8)
self.pyboy.tick(4 * self.action_freq, render=True)

# Scroll until the field move is found
Expand All @@ -812,14 +804,12 @@ def cut_if_next(self):
current_item = self.read_m("wCurrentMenuItem")
if current_item < 4 and FieldMoves.CUT.value == field_moves[current_item]:
break
self.pyboy.send_input(WindowEvent.PRESS_ARROW_DOWN)
self.pyboy.send_input(WindowEvent.RELEASE_ARROW_DOWN, delay=8)
self.pyboy.button("DOWN", delay=8)
self.pyboy.tick(self.action_freq, render=True)

# press a bunch of times
for _ in range(5):
self.pyboy.send_input(WindowEvent.PRESS_BUTTON_A)
self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_A, delay=8)
self.pyboy.button("A", delay=8)
self.pyboy.tick(4 * self.action_freq, render=True)

def surf_if_attempt(self, action: WindowEvent):
Expand Down
2 changes: 2 additions & 0 deletions pokemonred_puffer/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import cv2
import matplotlib.colors as mcolors
import numpy as np
from numba import jit

KANTO_MAP_PATH = os.path.join(os.path.dirname(__file__), "kanto_map_dsv.png")
BACKGROUND = np.array(cv2.imread(KANTO_MAP_PATH))


@jit(nopython=True, nogil=True, parallel=True)
def make_pokemon_red_overlay(counts: np.ndarray):
# TODO: Rethink how this scaling works
# Divide by number of elements > 0
Expand Down
1 change: 0 additions & 1 deletion pokemonred_puffer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ def env_creator(
env = wrapper_class(env, pufferlib.namespace(**[x for x in cfg.values()][0]))
if async_wrapper and async_config:
env = AsyncWrapper(env, async_config["send_queues"], async_config["recv_queues"])
# env = pufferlib.postprocess.EpisodeStats(env)
return pufferlib.emulation.GymnasiumPufferEnv(env=env)

return env_creator
Expand Down
36 changes: 36 additions & 0 deletions pokemonred_puffer/wrappers/episode_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import numpy as np
import gymnasium

import pufferlib.utils


class EpisodeStatsWrapper(gymnasium.Wrapper):
def __init__(self, env, *args, **kwargs):
self.env = env
self.observation_space = env.observation_space
self.action_space = env.action_space
self.reset()

# TODO: Fix options. Maybe reimplement gymnasium.Wrapper with better compatibility
def reset(self, seed=None):
self.info = dict(episode_return=0, episode_length=0)
return super().reset(seed=seed)

def step(self, action):
observation, reward, terminated, truncated, info = super().step(action)

for k, v in pufferlib.utils.unroll_nested_dict(info):
if "exploration_map" in k:
self.info[k] = self.info.get(k, np.zeros_like(v)) + v
else:
self.info[k] = v

# self.info['episode_return'].append(reward)
self.info["episode_return"] += reward
self.info["episode_length"] += 1

info = {}
if terminated or truncated or self.info["episode_length"] % self.env.log_frequency == 0:
info = self.info

return observation, reward, terminated, truncated, info
20 changes: 10 additions & 10 deletions pokemonred_puffer/wrappers/stream_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,25 @@ class StreamWrapper(gym.Wrapper):

def __init__(self, env: RedGymEnv, config: pufferlib.namespace):
super().__init__(env)
with RedGymEnv.lock:
with StreamWrapper.lock:
env_id = (
(int(RedGymEnv.env_id.buf[0]) << 24)
+ (int(RedGymEnv.env_id.buf[1]) << 16)
+ (int(RedGymEnv.env_id.buf[2]) << 8)
+ (int(RedGymEnv.env_id.buf[3]))
(int(StreamWrapper.env_id.buf[0]) << 24)
+ (int(StreamWrapper.env_id.buf[1]) << 16)
+ (int(StreamWrapper.env_id.buf[2]) << 8)
+ (int(StreamWrapper.env_id.buf[3]))
)
self.env_id = env_id
env_id += 1
RedGymEnv.env_id.buf[0] = (env_id >> 24) & 0xFF
RedGymEnv.env_id.buf[1] = (env_id >> 16) & 0xFF
RedGymEnv.env_id.buf[2] = (env_id >> 8) & 0xFF
RedGymEnv.env_id.buf[3] = (env_id) & 0xFF
StreamWrapper.env_id.buf[0] = (env_id >> 24) & 0xFF
StreamWrapper.env_id.buf[1] = (env_id >> 16) & 0xFF
StreamWrapper.env_id.buf[2] = (env_id >> 8) & 0xFF
StreamWrapper.env_id.buf[3] = (env_id) & 0xFF

self.user = config.user
self.ws_address = "wss://transdimensional.xyz/broadcast"
self.stream_metadata = {
"user": self.user,
"env_id": env.env_id,
"env_id": self.env_id,
}
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ classifiers = [
dependencies = [
"einops",
"mediapy",
"numba",
"numpy",
"opencv-python",
"pyboy>=2",
Expand Down

0 comments on commit 3cc3b2e

Please sign in to comment.