Skip to content

Commit

Permalink
More prep for strength and surf
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Aug 12, 2024
1 parent 1dccf2a commit 65a226b
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 46 deletions.
121 changes: 75 additions & 46 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def __init__(self, env_config: pufferlib.namespace):
self.auto_solve_strength_puzzles = env_config.auto_solve_strength_puzzles
self.auto_remove_all_nonuseful_items = env_config.auto_remove_all_nonuseful_items
self.auto_pokeflute = env_config.auto_pokeflute
self.skip_safari_zone = env_config.skip_safari_zone
self.infinite_money = env_config.infinite_money
self.use_global_map = env_config.use_global_map
self.save_state = env_config.save_state
Expand Down Expand Up @@ -173,6 +174,8 @@ def __init__(self, env_config: pufferlib.namespace):
"blackout_map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8),
"battle_type": spaces.Box(low=0, high=4, shape=(1,), dtype=np.uint8),
"cut_in_party": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8),
"strength_in_party": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8),
"surf_in_party": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8),
# "x": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8),
# "y": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8),
"map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8),
Expand All @@ -196,8 +199,9 @@ def __init__(self, env_config: pufferlib.namespace):
"moves": spaces.Box(low=0, high=0xA4, shape=(6, 4), dtype=np.uint8),
# Add 3 for rival_3, game corner rocket and saffron guard
"events": spaces.Box(low=0, high=1, shape=(len(EVENTS) + 3,), dtype=np.uint8),
"safari_steps": spaces.Box(low=0, high=1.0, shape=(1,), dtype=np.uint32),
}
if not self.skip_safari_zone:
obs_dict["safari_steps"] = spaces.Box(low=0, high=1.0, shape=(1,), dtype=np.uint32)

if self.use_global_map:
obs_dict["global_map"] = spaces.Box(
Expand Down Expand Up @@ -351,7 +355,9 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] =
self.update_pokedex()
self.update_tm_hm_moves_obtained()
self.party_size = self.read_m("wPartyCount")
self.taught_cut = self.check_if_party_has_hm(0xF)
self.taught_cut = self.check_if_party_has_hm(TmHmMoves.CUT.value)
self.taught_surf = self.check_if_party_has_hm(TmHmMoves.SURF.value)
self.taught_strength = self.check_if_party_has_hm(TmHmMoves.STRENGTH.value)
self.levels_satisfied = False
self.base_explore = 0
self.max_opponent_level = 0
Expand Down Expand Up @@ -575,41 +581,58 @@ def _get_obs(self):
# item ids start at 1 so using 0 as the nothing value is okay
bag[2 * numBagItems :] = 0

return self.render() | {
"direction": np.array(
self.read_m("wSpritePlayerStateData1FacingDirection") // 4, dtype=np.uint8
),
"blackout_map_id": np.array(self.read_m("wLastBlackoutMap"), dtype=np.uint8),
"battle_type": np.array(self.read_m("wIsInBattle") + 1, dtype=np.uint8),
"cut_in_party": np.array(self.check_if_party_has_hm(0xF), dtype=np.uint8),
# "x": np.array(player_x, dtype=np.uint8),
# "y": np.array(player_y, dtype=np.uint8),
"map_id": np.array(self.read_m(0xD35E), dtype=np.uint8),
"bag_items": bag[::2].copy(),
"bag_quantity": bag[1::2].copy(),
"species": np.array([self.party[i].Species for i in range(6)], dtype=np.uint8),
"hp": np.array([self.party[i].HP for i in range(6)], dtype=np.uint32),
"status": np.array([self.party[i].Status for i in range(6)], dtype=np.uint8),
"type1": np.array([self.party[i].Type1 for i in range(6)], dtype=np.uint8),
"type2": np.array([self.party[i].Type2 for i in range(6)], dtype=np.uint8),
"level": np.array([self.party[i].Level for i in range(6)], dtype=np.uint8),
"maxHP": np.array([self.party[i].MaxHP for i in range(6)], dtype=np.uint32),
"attack": np.array([self.party[i].Attack for i in range(6)], dtype=np.uint32),
"defense": np.array([self.party[i].Defense for i in range(6)], dtype=np.uint32),
"speed": np.array([self.party[i].Speed for i in range(6)], dtype=np.uint32),
"special": np.array([self.party[i].Special for i in range(6)], dtype=np.uint32),
"moves": np.array([self.party[i].Moves for i in range(6)], dtype=np.uint8),
"events": np.array(
[self.events.get_event(event) for event in EVENTS]
+ [
self.read_m("wSSAnne2FCurScript") == 4, # rival 3
self.missables.get_missable("HS_GAME_CORNER_ROCKET"), # game corner rocket
self.wd728.get_bit("GAVE_SAFFRON_GUARD_DRINK"), # saffron guard
],
dtype=np.uint8,
),
"safari_steps": np.array(self.read_short("wSafariSteps"), dtype=np.uint32),
}
return (
self.render()
| {
"direction": np.array(
self.read_m("wSpritePlayerStateData1FacingDirection") // 4, dtype=np.uint8
),
"blackout_map_id": np.array(self.read_m("wLastBlackoutMap"), dtype=np.uint8),
"battle_type": np.array(self.read_m("wIsInBattle") + 1, dtype=np.uint8),
"cut_in_party": np.array(
self.check_if_party_has_hm(TmHmMoves.CUT.value), dtype=np.uint8
),
"surf_in_party": np.array(
self.check_if_party_has_hm(TmHmMoves.SURF.value), dtype=np.uint8
),
"strength_in_party": np.array(
self.check_if_party_has_hm(TmHmMoves.STRENGTH.value), dtype=np.uint8
),
# "x": np.array(player_x, dtype=np.uint8),
# "y": np.array(player_y, dtype=np.uint8),
"map_id": np.array(self.read_m(0xD35E), dtype=np.uint8),
"bag_items": bag[::2].copy(),
"bag_quantity": bag[1::2].copy(),
"species": np.array([self.party[i].Species for i in range(6)], dtype=np.uint8),
"hp": np.array([self.party[i].HP for i in range(6)], dtype=np.uint32),
"status": np.array([self.party[i].Status for i in range(6)], dtype=np.uint8),
"type1": np.array([self.party[i].Type1 for i in range(6)], dtype=np.uint8),
"type2": np.array([self.party[i].Type2 for i in range(6)], dtype=np.uint8),
"level": np.array([self.party[i].Level for i in range(6)], dtype=np.uint8),
"maxHP": np.array([self.party[i].MaxHP for i in range(6)], dtype=np.uint32),
"attack": np.array([self.party[i].Attack for i in range(6)], dtype=np.uint32),
"defense": np.array([self.party[i].Defense for i in range(6)], dtype=np.uint32),
"speed": np.array([self.party[i].Speed for i in range(6)], dtype=np.uint32),
"special": np.array([self.party[i].Special for i in range(6)], dtype=np.uint32),
"moves": np.array([self.party[i].Moves for i in range(6)], dtype=np.uint8),
"events": np.array(
[self.events.get_event(event) for event in EVENTS]
+ [
self.read_m("wSSAnne2FCurScript") == 4, # rival 3
self.missables.get_missable("HS_GAME_CORNER_ROCKET"), # game corner rocket
self.wd728.get_bit("GAVE_SAFFRON_GUARD_DRINK"), # saffron guard
],
dtype=np.uint8,
),
}
| (
{}
if self.skip_safari_zone
else {
"safari_steps": np.array(self.read_short("wSafariSteps"), dtype=np.uint32),
}
)
)

def set_perfect_iv_dvs(self):
party_size = self.read_m("wPartyCount")
Expand Down Expand Up @@ -665,7 +688,9 @@ def step(self, action):
self.update_map_progress()
if self.perfect_ivs:
self.set_perfect_iv_dvs()
self.taught_cut = self.check_if_party_has_hm(0xF)
self.taught_cut = self.check_if_party_has_hm(TmHmMoves.CUT.value)
self.taught_surf = self.check_if_party_has_hm(TmHmMoves.SURF.value)
self.taught_strength = self.check_if_party_has_hm(TmHmMoves.STRENGTH.value)
self.pokecenters[self.read_m("wLastBlackoutMap")] = 1
info = {}

Expand Down Expand Up @@ -723,19 +748,21 @@ def run_action_on_emulator(self, action):
self.pyboy.tick(self.action_freq, render=False)

if self.events.get_event("EVENT_GOT_HM01"):
if self.auto_teach_cut and not self.check_if_party_has_hm(0x0F):
if self.auto_teach_cut and not self.check_if_party_has_hm(TmHmMoves.CUT.value):
self.teach_hm(TmHmMoves.CUT.value, 30, CUT_SPECIES_IDS)
if self.auto_use_cut:
self.cut_if_next()

if self.events.get_event("EVENT_GOT_HM03"):
if self.auto_teach_surf and not self.check_if_party_has_hm(0x39):
if self.auto_teach_surf and not self.check_if_party_has_hm(TmHmMoves.SURF.value):
self.teach_hm(TmHmMoves.SURF.value, 15, SURF_SPECIES_IDS)
if self.auto_use_surf:
self.surf_if_attempt(VALID_ACTIONS[action])

if self.events.get_event("EVENT_GOT_HM04"):
if self.auto_teach_strength and not self.check_if_party_has_hm(0x46):
if self.auto_teach_strength and not self.check_if_party_has_hm(
TmHmMoves.STRENGTH.value
):
self.teach_hm(TmHmMoves.STRENGTH.value, 15, STRENGTH_SPECIES_IDS)
if self.auto_solve_strength_puzzles:
self.solve_missable_strength_puzzle()
Expand All @@ -745,7 +772,7 @@ def run_action_on_emulator(self, action):
self.use_pokeflute()

if self.get_game_coords() == (18, 4, 7) and self.skip_safari_zone:
self.skip_safari_zone()
self.skip_safari_zone_atn()

# One last tick just in case
self.pyboy.tick(1, render=True)
Expand Down Expand Up @@ -1141,7 +1168,7 @@ def solve_switch_strength_puzzle(self):
self.setup_enable_wild_ecounters()
break

def skip_safari_zone(self):
def skip_safari_zone_atn(self):
# First move down
self.pyboy.button("down", 8)
self.pyboy.tick(self.action_freq, render=self.animate_scripts)
Expand Down Expand Up @@ -1292,9 +1319,11 @@ def agent_stats(self, action):
"seen_pokemon": int(sum(self.seen_pokemon)),
"moves_obtained": int(sum(self.moves_obtained)),
"opponent_level": self.max_opponent_level,
"taught_cut": int(self.check_if_party_has_hm(0xF)),
"cut_coords": sum(self.cut_coords.values()),
"cut_tiles": len(self.cut_tiles),
"taught_cut": int(self.check_if_party_has_hm(TmHmMoves.CUT.value)),
"taught_surf": int(self.check_if_party_has_hm(TmHmMoves.SURF.value)),
"taught_strength": int(self.check_if_party_has_hm(TmHmMoves.STRENGTH.value)),
# "cut_coords": sum(self.cut_coords.values()),
# "cut_tiles": len(self.cut_tiles),
"menu": {
"start_menu": self.seen_start_menu,
"pokemon_menu": self.seen_pokemon_menu,
Expand Down
5 changes: 5 additions & 0 deletions pokemonred_puffer/policies/multi_convolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(
self.value_fn = nn.LazyLinear(1)

self.two_bit = env.unwrapped.env.two_bit
self.skip_safari_zone = env.unwrapped.skip_safari_zone
self.use_global_map = env.unwrapped.env.use_global_map

if self.use_global_map:
Expand Down Expand Up @@ -216,6 +217,8 @@ def encode_observations(self, observations):
one_hot(observations["battle_type"].int(), 4).float().squeeze(1),
# observations["cut_event"].float(),
observations["cut_in_party"].float(),
observations["strength_in_party"].float(),
observations["surf_in_party"].float(),
# observations["x"].float(),
# observations["y"].float(),
# one_hot(observations["map_id"].int(), 0xF7).float().squeeze(1),
Expand All @@ -229,6 +232,8 @@ def encode_observations(self, observations):
),
dim=-1,
)
if self.skip_safari_zone:
cat_obs = torch.cat((cat_obs, observations["safari_steps"].float()), dim=-1)
if self.use_global_map:
cat_obs = torch.cat(
(
Expand Down

0 comments on commit 65a226b

Please sign in to comment.