From 65a226b692a91e8fed905db1455d761efa420132 Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Sun, 11 Aug 2024 21:11:39 -0400 Subject: [PATCH] More prep for strength and surf --- pokemonred_puffer/environment.py | 121 +++++++++++------- .../policies/multi_convolutional.py | 5 + 2 files changed, 80 insertions(+), 46 deletions(-) diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 5ae97d7..32cae14 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -120,6 +120,7 @@ def __init__(self, env_config: pufferlib.namespace): self.auto_solve_strength_puzzles = env_config.auto_solve_strength_puzzles self.auto_remove_all_nonuseful_items = env_config.auto_remove_all_nonuseful_items self.auto_pokeflute = env_config.auto_pokeflute + self.skip_safari_zone = env_config.skip_safari_zone self.infinite_money = env_config.infinite_money self.use_global_map = env_config.use_global_map self.save_state = env_config.save_state @@ -173,6 +174,8 @@ def __init__(self, env_config: pufferlib.namespace): "blackout_map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8), "battle_type": spaces.Box(low=0, high=4, shape=(1,), dtype=np.uint8), "cut_in_party": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), + "strength_in_party": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), + "surf_in_party": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), # "x": spaces.Box(low=0, high=255, shape=(1,), dtype=np.u`int8), # "y": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8), "map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8), @@ -196,8 +199,9 @@ def __init__(self, env_config: pufferlib.namespace): "moves": spaces.Box(low=0, high=0xA4, shape=(6, 4), dtype=np.uint8), # Add 3 for rival_3, game corner rocket and saffron guard "events": spaces.Box(low=0, high=1, shape=(len(EVENTS) + 3,), dtype=np.uint8), - "safari_steps": spaces.Box(low=0, high=1.0, shape=(1,), dtype=np.uint32), } + if not self.skip_safari_zone: + obs_dict["safari_steps"] = spaces.Box(low=0, high=1.0, shape=(1,), dtype=np.uint32) if self.use_global_map: obs_dict["global_map"] = spaces.Box( @@ -351,7 +355,9 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = self.update_pokedex() self.update_tm_hm_moves_obtained() self.party_size = self.read_m("wPartyCount") - self.taught_cut = self.check_if_party_has_hm(0xF) + self.taught_cut = self.check_if_party_has_hm(TmHmMoves.CUT.value) + self.taught_surf = self.check_if_party_has_hm(TmHmMoves.SURF.value) + self.taught_strength = self.check_if_party_has_hm(TmHmMoves.STRENGTH.value) self.levels_satisfied = False self.base_explore = 0 self.max_opponent_level = 0 @@ -575,41 +581,58 @@ def _get_obs(self): # item ids start at 1 so using 0 as the nothing value is okay bag[2 * numBagItems :] = 0 - return self.render() | { - "direction": np.array( - self.read_m("wSpritePlayerStateData1FacingDirection") // 4, dtype=np.uint8 - ), - "blackout_map_id": np.array(self.read_m("wLastBlackoutMap"), dtype=np.uint8), - "battle_type": np.array(self.read_m("wIsInBattle") + 1, dtype=np.uint8), - "cut_in_party": np.array(self.check_if_party_has_hm(0xF), dtype=np.uint8), - # "x": np.array(player_x, dtype=np.uint8), - # "y": np.array(player_y, dtype=np.uint8), - "map_id": np.array(self.read_m(0xD35E), dtype=np.uint8), - "bag_items": bag[::2].copy(), - "bag_quantity": bag[1::2].copy(), - "species": np.array([self.party[i].Species for i in range(6)], dtype=np.uint8), - "hp": np.array([self.party[i].HP for i in range(6)], dtype=np.uint32), - "status": np.array([self.party[i].Status for i in range(6)], dtype=np.uint8), - "type1": np.array([self.party[i].Type1 for i in range(6)], dtype=np.uint8), - "type2": np.array([self.party[i].Type2 for i in range(6)], dtype=np.uint8), - "level": np.array([self.party[i].Level for i in range(6)], dtype=np.uint8), - "maxHP": np.array([self.party[i].MaxHP for i in range(6)], dtype=np.uint32), - "attack": np.array([self.party[i].Attack for i in range(6)], dtype=np.uint32), - "defense": np.array([self.party[i].Defense for i in range(6)], dtype=np.uint32), - "speed": np.array([self.party[i].Speed for i in range(6)], dtype=np.uint32), - "special": np.array([self.party[i].Special for i in range(6)], dtype=np.uint32), - "moves": np.array([self.party[i].Moves for i in range(6)], dtype=np.uint8), - "events": np.array( - [self.events.get_event(event) for event in EVENTS] - + [ - self.read_m("wSSAnne2FCurScript") == 4, # rival 3 - self.missables.get_missable("HS_GAME_CORNER_ROCKET"), # game corner rocket - self.wd728.get_bit("GAVE_SAFFRON_GUARD_DRINK"), # saffron guard - ], - dtype=np.uint8, - ), - "safari_steps": np.array(self.read_short("wSafariSteps"), dtype=np.uint32), - } + return ( + self.render() + | { + "direction": np.array( + self.read_m("wSpritePlayerStateData1FacingDirection") // 4, dtype=np.uint8 + ), + "blackout_map_id": np.array(self.read_m("wLastBlackoutMap"), dtype=np.uint8), + "battle_type": np.array(self.read_m("wIsInBattle") + 1, dtype=np.uint8), + "cut_in_party": np.array( + self.check_if_party_has_hm(TmHmMoves.CUT.value), dtype=np.uint8 + ), + "surf_in_party": np.array( + self.check_if_party_has_hm(TmHmMoves.SURF.value), dtype=np.uint8 + ), + "strength_in_party": np.array( + self.check_if_party_has_hm(TmHmMoves.STRENGTH.value), dtype=np.uint8 + ), + # "x": np.array(player_x, dtype=np.uint8), + # "y": np.array(player_y, dtype=np.uint8), + "map_id": np.array(self.read_m(0xD35E), dtype=np.uint8), + "bag_items": bag[::2].copy(), + "bag_quantity": bag[1::2].copy(), + "species": np.array([self.party[i].Species for i in range(6)], dtype=np.uint8), + "hp": np.array([self.party[i].HP for i in range(6)], dtype=np.uint32), + "status": np.array([self.party[i].Status for i in range(6)], dtype=np.uint8), + "type1": np.array([self.party[i].Type1 for i in range(6)], dtype=np.uint8), + "type2": np.array([self.party[i].Type2 for i in range(6)], dtype=np.uint8), + "level": np.array([self.party[i].Level for i in range(6)], dtype=np.uint8), + "maxHP": np.array([self.party[i].MaxHP for i in range(6)], dtype=np.uint32), + "attack": np.array([self.party[i].Attack for i in range(6)], dtype=np.uint32), + "defense": np.array([self.party[i].Defense for i in range(6)], dtype=np.uint32), + "speed": np.array([self.party[i].Speed for i in range(6)], dtype=np.uint32), + "special": np.array([self.party[i].Special for i in range(6)], dtype=np.uint32), + "moves": np.array([self.party[i].Moves for i in range(6)], dtype=np.uint8), + "events": np.array( + [self.events.get_event(event) for event in EVENTS] + + [ + self.read_m("wSSAnne2FCurScript") == 4, # rival 3 + self.missables.get_missable("HS_GAME_CORNER_ROCKET"), # game corner rocket + self.wd728.get_bit("GAVE_SAFFRON_GUARD_DRINK"), # saffron guard + ], + dtype=np.uint8, + ), + } + | ( + {} + if self.skip_safari_zone + else { + "safari_steps": np.array(self.read_short("wSafariSteps"), dtype=np.uint32), + } + ) + ) def set_perfect_iv_dvs(self): party_size = self.read_m("wPartyCount") @@ -665,7 +688,9 @@ def step(self, action): self.update_map_progress() if self.perfect_ivs: self.set_perfect_iv_dvs() - self.taught_cut = self.check_if_party_has_hm(0xF) + self.taught_cut = self.check_if_party_has_hm(TmHmMoves.CUT.value) + self.taught_surf = self.check_if_party_has_hm(TmHmMoves.SURF.value) + self.taught_strength = self.check_if_party_has_hm(TmHmMoves.STRENGTH.value) self.pokecenters[self.read_m("wLastBlackoutMap")] = 1 info = {} @@ -723,19 +748,21 @@ def run_action_on_emulator(self, action): self.pyboy.tick(self.action_freq, render=False) if self.events.get_event("EVENT_GOT_HM01"): - if self.auto_teach_cut and not self.check_if_party_has_hm(0x0F): + if self.auto_teach_cut and not self.check_if_party_has_hm(TmHmMoves.CUT.value): self.teach_hm(TmHmMoves.CUT.value, 30, CUT_SPECIES_IDS) if self.auto_use_cut: self.cut_if_next() if self.events.get_event("EVENT_GOT_HM03"): - if self.auto_teach_surf and not self.check_if_party_has_hm(0x39): + if self.auto_teach_surf and not self.check_if_party_has_hm(TmHmMoves.SURF.value): self.teach_hm(TmHmMoves.SURF.value, 15, SURF_SPECIES_IDS) if self.auto_use_surf: self.surf_if_attempt(VALID_ACTIONS[action]) if self.events.get_event("EVENT_GOT_HM04"): - if self.auto_teach_strength and not self.check_if_party_has_hm(0x46): + if self.auto_teach_strength and not self.check_if_party_has_hm( + TmHmMoves.STRENGTH.value + ): self.teach_hm(TmHmMoves.STRENGTH.value, 15, STRENGTH_SPECIES_IDS) if self.auto_solve_strength_puzzles: self.solve_missable_strength_puzzle() @@ -745,7 +772,7 @@ def run_action_on_emulator(self, action): self.use_pokeflute() if self.get_game_coords() == (18, 4, 7) and self.skip_safari_zone: - self.skip_safari_zone() + self.skip_safari_zone_atn() # One last tick just in case self.pyboy.tick(1, render=True) @@ -1141,7 +1168,7 @@ def solve_switch_strength_puzzle(self): self.setup_enable_wild_ecounters() break - def skip_safari_zone(self): + def skip_safari_zone_atn(self): # First move down self.pyboy.button("down", 8) self.pyboy.tick(self.action_freq, render=self.animate_scripts) @@ -1292,9 +1319,11 @@ def agent_stats(self, action): "seen_pokemon": int(sum(self.seen_pokemon)), "moves_obtained": int(sum(self.moves_obtained)), "opponent_level": self.max_opponent_level, - "taught_cut": int(self.check_if_party_has_hm(0xF)), - "cut_coords": sum(self.cut_coords.values()), - "cut_tiles": len(self.cut_tiles), + "taught_cut": int(self.check_if_party_has_hm(TmHmMoves.CUT.value)), + "taught_surf": int(self.check_if_party_has_hm(TmHmMoves.SURF.value)), + "taught_strength": int(self.check_if_party_has_hm(TmHmMoves.STRENGTH.value)), + # "cut_coords": sum(self.cut_coords.values()), + # "cut_tiles": len(self.cut_tiles), "menu": { "start_menu": self.seen_start_menu, "pokemon_menu": self.seen_pokemon_menu, diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py index a3d603c..848da62 100644 --- a/pokemonred_puffer/policies/multi_convolutional.py +++ b/pokemonred_puffer/policies/multi_convolutional.py @@ -58,6 +58,7 @@ def __init__( self.value_fn = nn.LazyLinear(1) self.two_bit = env.unwrapped.env.two_bit + self.skip_safari_zone = env.unwrapped.skip_safari_zone self.use_global_map = env.unwrapped.env.use_global_map if self.use_global_map: @@ -216,6 +217,8 @@ def encode_observations(self, observations): one_hot(observations["battle_type"].int(), 4).float().squeeze(1), # observations["cut_event"].float(), observations["cut_in_party"].float(), + observations["strength_in_party"].float(), + observations["surf_in_party"].float(), # observations["x"].float(), # observations["y"].float(), # one_hot(observations["map_id"].int(), 0xF7).float().squeeze(1), @@ -229,6 +232,8 @@ def encode_observations(self, observations): ), dim=-1, ) + if self.skip_safari_zone: + cat_obs = torch.cat((cat_obs, observations["safari_steps"].float()), dim=-1) if self.use_global_map: cat_obs = torch.cat( (