From e23cf9e270eb2afa26aa78152bfa871b2a7d52e5 Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Sat, 27 Jul 2024 23:01:47 -0400 Subject: [PATCH] hook based ticking --- pokemonred_puffer/environment.py | 24 +++++++++++++++---- .../policies/multi_convolutional.py | 1 - 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 2422f42..b6e3d36 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -168,7 +168,6 @@ def __init__(self, env_config: pufferlib.namespace): # "y": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8), "map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8), # "badges": spaces.Box(low=0, high=np.iinfo(np.uint16).max, shape=(1,), dtype=np.uint16), - "wJoyIgnore": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), "bag_items": spaces.Box( low=0, high=max(Items._value2member_map_.keys()), shape=(20,), dtype=np.uint8 ), @@ -252,6 +251,7 @@ def register_hooks(self): if self.disable_wild_encounters: self.setup_disable_wild_encounters() self.pyboy.hook_register(None, "AnimateHealingMachine", self.pokecenter_heal_hook, None) + self.pyboy.hook_register(None, "OverworldLoopLessDelay", self.overworld_loop_hook, None) def setup_disable_wild_encounters(self): bank, addr = self.pyboy.symbol_lookup("TryDoWildEncounter.gotWildEncounterType") @@ -556,7 +556,6 @@ def _get_obs(self): # "x": np.array(player_x, dtype=np.uint8), # "y": np.array(player_y, dtype=np.uint8), "map_id": np.array(self.read_m(0xD35E), dtype=np.uint8), - "wJoyIgnore": np.array(self.read_m("wJoyIgnore"), dtype=np.uint8), "bag_items": bag[::2].copy(), "bag_quantity": bag[1::2].copy(), "species": np.array([self.party[i].Species for i in range(6)], dtype=np.uint8), @@ -678,8 +677,17 @@ def run_action_on_emulator(self, action): if not self.disable_ai_actions: self.pyboy.send_input(VALID_ACTIONS[action]) - self.pyboy.send_input(VALID_RELEASE_ACTIONS[action], delay=8) - self.pyboy.tick(self.action_freq, render=True) + # TODO: Function-ize this + self.user_control = False + while not self.user_control: + self.pyboy.tick(1, render=False) + + self.user_control = False + self.pyboy.send_input(VALID_RELEASE_ACTIONS[action]) + while not self.user_control: + self.pyboy.tick(1, render=False) + else: + self.pyboy.tick(self.action_freq, render=True) if self.events.get_event("EVENT_GOT_HM01"): if self.auto_teach_cut and not self.check_if_party_has_hm(0x0F): @@ -703,6 +711,9 @@ def run_action_on_emulator(self, action): if self.events.get_event("EVENT_GOT_POKE_FLUTE") and self.auto_pokeflute: self.use_pokeflute() + # One last tick just in case + self.pyboy.tick(1, render=True) + def party_has_cut_capable_mon(self): # find bulba and replace tackle (first skill) with cut party_size = self.read_m("wPartyCount") @@ -1129,6 +1140,9 @@ def blackout_update_hook(self, *args, **kwargs): def pokecenter_heal_hook(self, *args, **kwargs): self.pokecenter_heal = 1 + def overworld_loop_hook(self, *args, **kwargs): + self.user_control = True + def cut_hook(self, context): player_direction = self.pyboy.memory[ self.pyboy.symbol_lookup("wSpritePlayerStateData1FacingDirection")[1] @@ -1279,7 +1293,7 @@ def get_game_coords(self): return (self.read_m(0xD362), self.read_m(0xD361), self.read_m(0xD35E)) def update_seen_coords(self): - inc = 0.0 if (self.read_m("wd736") & 0b1000_0000) else self.exploration_inc + inc = 0.5 if (self.read_m("wd736") & 0b1000_0000) else self.exploration_inc x_pos, y_pos, map_n = self.get_game_coords() # self.seen_coords[(x_pos, y_pos, map_n)] = inc diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py index 34bf21e..8ce9f2b 100644 --- a/pokemonred_puffer/policies/multi_convolutional.py +++ b/pokemonred_puffer/policies/multi_convolutional.py @@ -221,7 +221,6 @@ def encode_observations(self, observations): # badges.float().squeeze(1), map_id.squeeze(1), blackout_map_id.squeeze(1), - observations["wJoyIgnore"].float(), items.flatten(start_dim=1), party_latent, event_obs,