Skip to content

Commit

Permalink
More abstraction
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Jun 9, 2024
1 parent bc97144 commit 10025e9
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
12 changes: 6 additions & 6 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,7 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] =

self.update_pokedex()
self.update_tm_hm_moves_obtained()
self.taught_cut = self.check_if_party_has_cut()
self.taught_cut = self.check_if_party_has_hm(0xF)
self.levels_satisfied = False
self.base_explore = 0
self.max_opponent_level = 0
Expand Down Expand Up @@ -857,7 +857,7 @@ def _get_obs(self):
"blackout_map_id": np.array(self.read_m("wLastBlackoutMap"), dtype=np.uint8),
"battle_type": np.array(self.read_m("wIsInBattle") + 1, dtype=np.uint8),
"cut_event": np.array(self.read_bit(0xD803, 0), dtype=np.uint8),
"cut_in_party": np.array(self.check_if_party_has_cut(), dtype=np.uint8),
"cut_in_party": np.array(self.check_if_party_has_hm(0xF), dtype=np.uint8),
# "x": np.array(player_x, dtype=np.uint8),
# "y": np.array(player_y, dtype=np.uint8),
"map_id": np.array(self.read_m(0xD35E), dtype=np.uint8),
Expand All @@ -871,12 +871,12 @@ def set_perfect_iv_dvs(self):
_, addr = self.pyboy.symbol_lookup(f"wPartyMon{i+1}Species")
self.pyboy.memory[addr + 17 : addr + 17 + 12] = 0xFF

def check_if_party_has_cut(self) -> bool:
def check_if_party_has_hm(self, hm: int) -> bool:
party_size = self.read_m("wPartyCount")
for i in range(party_size):
# PRET 1-indexes
_, addr = self.pyboy.symbol_lookup(f"wPartyMon{i+1}Moves")
if 15 in self.pyboy.memory[addr : addr + 4]:
if hm in self.pyboy.memory[addr : addr + 4]:
return True
return False

Expand All @@ -900,7 +900,7 @@ def step(self, action):
self.update_map_progress()
if self.perfect_ivs:
self.set_perfect_iv_dvs()
self.taught_cut = self.check_if_party_has_cut()
self.taught_cut = self.check_if_party_has_hm(0xF)
self.pokecenters[self.read_m("wLastBlackoutMap")] = 1
info = {}

Expand Down Expand Up @@ -1164,7 +1164,7 @@ def agent_stats(self, action):
"left_bills_house_after_helping": int(self.read_bit(0xD7F2, 7)),
"got_hm01": int(self.read_bit(0xD803, 0)),
"rubbed_captains_back": int(self.read_bit(0xD803, 1)),
"taught_cut": int(self.check_if_party_has_cut()),
"taught_cut": int(self.check_if_party_has_hm(0xF)),
"cut_coords": sum(self.cut_coords.values()),
"cut_tiles": len(self.cut_tiles),
"start_menu": self.seen_start_menu,
Expand Down
2 changes: 1 addition & 1 deletion pokemonred_puffer/rewards/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def get_game_state_reward(self):
# "heal": self.total_healing_rew,
"explore": sum(self.seen_coords.values()) * 0.012,
# "explore_maps": np.sum(self.seen_map_ids) * 0.0001,
"taught_cut": 4 * int(self.check_if_party_has_cut()),
"taught_cut": 4 * int(self.check_if_party_has_hm(0xF)),
"cut_coords": sum(self.cut_coords.values()) * 1.0,
"cut_tiles": sum(self.cut_tiles.values()) * 1.0,
"met_bill": 5 * int(self.read_bit(0xD7F1, 0)),
Expand Down

0 comments on commit 10025e9

Please sign in to comment.