Skip to content

Commit

Permalink
hook based ticking
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Jul 28, 2024
1 parent 2385a12 commit e23cf9e
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
24 changes: 19 additions & 5 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ def __init__(self, env_config: pufferlib.namespace):
# "y": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8),
"map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8),
# "badges": spaces.Box(low=0, high=np.iinfo(np.uint16).max, shape=(1,), dtype=np.uint16),
"wJoyIgnore": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8),
"bag_items": spaces.Box(
low=0, high=max(Items._value2member_map_.keys()), shape=(20,), dtype=np.uint8
),
Expand Down Expand Up @@ -252,6 +251,7 @@ def register_hooks(self):
if self.disable_wild_encounters:
self.setup_disable_wild_encounters()
self.pyboy.hook_register(None, "AnimateHealingMachine", self.pokecenter_heal_hook, None)
self.pyboy.hook_register(None, "OverworldLoopLessDelay", self.overworld_loop_hook, None)

def setup_disable_wild_encounters(self):
bank, addr = self.pyboy.symbol_lookup("TryDoWildEncounter.gotWildEncounterType")
Expand Down Expand Up @@ -556,7 +556,6 @@ def _get_obs(self):
# "x": np.array(player_x, dtype=np.uint8),
# "y": np.array(player_y, dtype=np.uint8),
"map_id": np.array(self.read_m(0xD35E), dtype=np.uint8),
"wJoyIgnore": np.array(self.read_m("wJoyIgnore"), dtype=np.uint8),
"bag_items": bag[::2].copy(),
"bag_quantity": bag[1::2].copy(),
"species": np.array([self.party[i].Species for i in range(6)], dtype=np.uint8),
Expand Down Expand Up @@ -678,8 +677,17 @@ def run_action_on_emulator(self, action):

if not self.disable_ai_actions:
self.pyboy.send_input(VALID_ACTIONS[action])
self.pyboy.send_input(VALID_RELEASE_ACTIONS[action], delay=8)
self.pyboy.tick(self.action_freq, render=True)
# TODO: Function-ize this
self.user_control = False
while not self.user_control:
self.pyboy.tick(1, render=False)

self.user_control = False
self.pyboy.send_input(VALID_RELEASE_ACTIONS[action])
while not self.user_control:
self.pyboy.tick(1, render=False)
else:
self.pyboy.tick(self.action_freq, render=True)

if self.events.get_event("EVENT_GOT_HM01"):
if self.auto_teach_cut and not self.check_if_party_has_hm(0x0F):
Expand All @@ -703,6 +711,9 @@ def run_action_on_emulator(self, action):
if self.events.get_event("EVENT_GOT_POKE_FLUTE") and self.auto_pokeflute:
self.use_pokeflute()

# One last tick just in case
self.pyboy.tick(1, render=True)

def party_has_cut_capable_mon(self):
# find bulba and replace tackle (first skill) with cut
party_size = self.read_m("wPartyCount")
Expand Down Expand Up @@ -1129,6 +1140,9 @@ def blackout_update_hook(self, *args, **kwargs):
def pokecenter_heal_hook(self, *args, **kwargs):
self.pokecenter_heal = 1

def overworld_loop_hook(self, *args, **kwargs):
self.user_control = True

def cut_hook(self, context):
player_direction = self.pyboy.memory[
self.pyboy.symbol_lookup("wSpritePlayerStateData1FacingDirection")[1]
Expand Down Expand Up @@ -1279,7 +1293,7 @@ def get_game_coords(self):
return (self.read_m(0xD362), self.read_m(0xD361), self.read_m(0xD35E))

def update_seen_coords(self):
inc = 0.0 if (self.read_m("wd736") & 0b1000_0000) else self.exploration_inc
inc = 0.5 if (self.read_m("wd736") & 0b1000_0000) else self.exploration_inc

x_pos, y_pos, map_n = self.get_game_coords()
# self.seen_coords[(x_pos, y_pos, map_n)] = inc
Expand Down
1 change: 0 additions & 1 deletion pokemonred_puffer/policies/multi_convolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,6 @@ def encode_observations(self, observations):
# badges.float().squeeze(1),
map_id.squeeze(1),
blackout_map_id.squeeze(1),
observations["wJoyIgnore"].float(),
items.flatten(start_dim=1),
party_latent,
event_obs,
Expand Down

0 comments on commit e23cf9e

Please sign in to comment.