From b99f4a129d95e033b29b1bd4b1a6f95d19beb968 Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Tue, 17 Dec 2024 21:55:12 -0500 Subject: [PATCH] Add in pokeflute and surf coords --- config.yaml | 37 +++++++++++++-- pokemonred_puffer/environment.py | 55 ++++++++++++++++++++++- pokemonred_puffer/rewards/baseline.py | 14 +++++- pokemonred_puffer/wrappers/exploration.py | 18 ++++++++ 4 files changed, 118 insertions(+), 6 deletions(-) diff --git a/config.yaml b/config.yaml index d0f9ffb..0521b77 100644 --- a/config.yaml +++ b/config.yaml @@ -66,12 +66,12 @@ env: auto_teach_cut: False auto_use_cut: False auto_use_strength: True - auto_use_surf: True - auto_teach_surf: True + auto_use_surf: False + auto_teach_surf: False auto_teach_strength: True auto_solve_strength_puzzles: True auto_remove_all_nonuseful_items: True - auto_pokeflute: True + auto_pokeflute: False auto_next_elevator_floor: True skip_safari_zone: False infinite_safari_steps: False @@ -341,6 +341,37 @@ rewards: use_surf: 0.4 useful_item: 0.825 safari_zone: 3.4493650422686217 + + baseline.ObjectRewardRequiredEventsMapIdsFieldMoves: + reward: + a_press: 0.0 # 0.00001 + badges: 3.0 + bag_menu: 0.0 + caught_pokemon: 2.5 + valid_cut_coords: 0.75 + valid_surf_coords: .75 + event: .75 + exploration: 0.018999755680454297 + explore_hidden_objs: 0.00009999136567868017 + explore_signs: 0.015025767686371013 + explore_warps: 0.010135211705238394 + hm_count: 7.5 + invalid_cut_coords: 0.0001 + invalid_surf_coords: 0.0001 + level: 1.05 + moves_obtained: 4.0 + pokecenter_heal: 0.47 + pokeflute_coords: 0.0001 + pokemon_menu: 0.0 + required_event: 7.0 + required_item: 3.0 + seen_action_bag_menu: 0.0 + seen_pokemon: 2.5 + start_menu: 0.0 + stats_menu: 0.0 + use_surf: 0.0 + useful_item: 0.825 + safari_zone: 3.4493650422686217 policies: multi_convolutional.MultiConvolutionalPolicy: diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 
def _facing_tile_coords(self):
    """Return the ``(x, y, map_id)`` of the tile directly in front of the player.

    The facing direction is read from the emulator memory at the address of
    ``wSpritePlayerStateData1FacingDirection`` and the position reported by
    ``self.get_game_coords()`` is shifted by one tile in that direction.

    Returns:
        A ``(x, y, map_id)`` tuple, or ``None`` when the direction byte is not
        one of the four handled values (0x0, 0x4, 0x8, 0xC).
    """
    # Raw facing value -> (dx, dy). 0x0 is "down" (y + 1); the remaining
    # entries step y - 1, x - 1 and x + 1 respectively.
    offsets = {0x0: (0, 1), 0x4: (0, -1), 0x8: (-1, 0), 0xC: (1, 0)}
    _, facing_addr = self.pyboy.symbol_lookup("wSpritePlayerStateData1FacingDirection")
    step = offsets.get(self.pyboy.memory[facing_addr])
    if step is None:
        # Previously this case left ``coords`` unbound and crashed the hook.
        return None
    x, y, map_id = self.get_game_coords()
    return (x + step[0], y + step[1], map_id)


def pokeflute_hook(self, *args, **kwargs):
    """Record the tile the player was facing when the poke flute hook fired.

    Registered on ``ItemUsePokeFlute.noSnorlaxToWakeUp``. The faced tile is
    stored as a key of ``self.pokeflute_coords`` with value ``1``. Nothing is
    recorded if the facing direction cannot be decoded.
    """
    coords = self._facing_tile_coords()
    if coords is not None:
        self.pokeflute_coords[coords] = 1


def surf_hook(self, context: bool, *args, **kwargs):
    """Record the tile the player was facing when a surf hook fired.

    Args:
        context: Truthy when the hook was registered for the successful surf
            path (``ItemUseSurfboard.surf``); falsy for ``SurfingAttemptFailed``.

    The faced tile is stored with value ``1`` in ``self.valid_surf_coords``
    when ``context`` is truthy and in ``self.invalid_surf_coords`` otherwise.
    Nothing is recorded if the facing direction cannot be decoded.
    """
    coords = self._facing_tile_coords()
    if coords is None:
        return
    target = self.valid_surf_coords if context else self.invalid_surf_coords
    target[coords] = 1
class ObjectRewardRequiredEventsMapIdsFieldMoves(ObjectRewardRequiredEventsMapIds):
    """Reward env that adds poke flute and surf coordinate terms to the parent's rewards."""

    def get_game_state_reward(self) -> dict[str, float]:
        """Return the parent's reward components plus the field-move terms.

        Each added term is the configured weight from ``self.reward_config``
        multiplied by the number of entries in the matching coordinate dict
        (``pokeflute_coords``, ``valid_surf_coords``, ``invalid_surf_coords``).
        Keys produced here override same-named keys from the parent.
        """
        return super().get_game_state_reward() | {
            "pokeflute_coords": self.reward_config["pokeflute_coords"]
            * len(self.pokeflute_coords),
            "valid_surf_coords": self.reward_config["valid_surf_coords"]
            * len(self.valid_surf_coords),
            # Was counting invalid_cut_coords (copy-paste slip), which tied the
            # surf reward to failed cut attempts instead of failed surf attempts.
            "invalid_surf_coords": self.reward_config["invalid_surf_coords"]
            * len(self.invalid_surf_coords),
        }