diff --git a/pokemonred_puffer/data/events.py b/pokemonred_puffer/data/events.py
index 9a4c4a8..8878155 100644
--- a/pokemonred_puffer/data/events.py
+++ b/pokemonred_puffer/data/events.py
@@ -1,5 +1,6 @@
 from ctypes import c_uint8, LittleEndianStructure, Union
 import re
+from typing import Iterable, Iterator
 
 from pyboy import PyBoy
 
@@ -2589,6 +2590,13 @@ def get_event(self, event_name: str) -> int:
         """
         return getattr(self.b, event_name)
 
+    def get_events(self, event_names: Iterable[str]) -> Iterator[int]:
+        """
+        Lazily yield each named event flag in input order: 1 if true, 0 if false
+        """
+        for event_name in event_names:
+            yield getattr(self.b, event_name)
+
     def set_event(self, event_name: str, value: bool):
         # This is O(N) but it's so rare that I'm not too worried about it
         idx = [x[0] for x in self.b._fields_].index(event_name)
diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py
index 03b5c54..47f3cdd 100644
--- a/pokemonred_puffer/environment.py
+++ b/pokemonred_puffer/environment.py
@@ -617,14 +617,23 @@ def _get_obs(self):
             "speed": np.array([self.party[i].Speed for i in range(6)], dtype=np.uint32),
             "special": np.array([self.party[i].Special for i in range(6)], dtype=np.uint32),
             "moves": np.array([self.party[i].Moves for i in range(6)], dtype=np.uint8),
-            "events": np.array(
-                [self.events.get_event(event) for event in EVENTS]
-                + [
-                    self.read_m("wSSAnne2FCurScript") == 4,  # rival 3
-                    self.missables.get_missable("HS_GAME_CORNER_ROCKET"),  # game corner rocket
-                    self.flags.get_bit("BIT_GAVE_SAFFRON_GUARDS_DRINK"),  # saffron guard
-                    self.flags.get_bit("BIT_GOT_LAPRAS"),  # got lapras
-                ],
+            "events": np.concatenate(
+                (
+                    np.fromiter(self.events.get_events(EVENTS), dtype=np.uint8),
+                    # Cast the extras explicitly: a signed-int array cannot be
+                    # concatenated into uint8 under the default same_kind casting.
+                    np.array(
+                        [
+                            self.read_m("wSSAnne2FCurScript") == 4,  # rival 3
+                            self.missables.get_missable(
+                                "HS_GAME_CORNER_ROCKET"
+                            ),  # game corner rocket
+                            self.flags.get_bit("BIT_GAVE_SAFFRON_GUARDS_DRINK"),  # saffron guard
+                            self.flags.get_bit("BIT_GOT_LAPRAS"),  # got lapras
+                        ],
+                        dtype=np.uint8,
+                    ),
+                ),
                 dtype=np.uint8,
             ),
         }
@@ -1613,9 +1622,13 @@ def update_health(self):
 
     def update_pokedex(self):
         # TODO: Make a hook
-        size = 0xD30A - 0xD2F7
-        caught_mem = self.pyboy.memory[0xD2F7 : 0xD2F7 + size]
-        seen_mem = self.pyboy.memory[0xD30A : 0xD30A + size]
+        _, wPokedexOwned = self.pyboy.symbol_lookup("wPokedexOwned")
+        _, wPokedexOwnedEnd = self.pyboy.symbol_lookup("wPokedexOwnedEnd")
+        _, wPokedexSeen = self.pyboy.symbol_lookup("wPokedexSeen")
+        _, wPokedexSeenEnd = self.pyboy.symbol_lookup("wPokedexSeenEnd")
+
+        caught_mem = self.pyboy.memory[wPokedexOwned:wPokedexOwnedEnd]
+        seen_mem = self.pyboy.memory[wPokedexSeen:wPokedexSeenEnd]
         self.caught_pokemon = np.unpackbits(np.array(caught_mem, dtype=np.uint8))
         self.seen_pokemon = np.unpackbits(np.array(seen_mem, dtype=np.uint8))
 
@@ -1709,7 +1722,13 @@ def get_levels_reward(self):
 
     def get_required_events(self) -> set[str]:
         return (
-            {event for event in REQUIRED_EVENTS if self.events.get_event(event)}
+            {
+                event
+                for event, is_set in zip(
+                    REQUIRED_EVENTS, self.events.get_events(REQUIRED_EVENTS)
+                )
+                if is_set
+            }
             | ({"rival3"} if (self.read_m("wSSAnne2FCurScript") == 4) else set())
             | (
                 {"game_corner_rocket"}
diff --git a/pokemonred_puffer/rewards/baseline.py b/pokemonred_puffer/rewards/baseline.py
index cb047fa..43d7844 100644
--- a/pokemonred_puffer/rewards/baseline.py
+++ b/pokemonred_puffer/rewards/baseline.py
@@ -376,20 +376,22 @@ def get_game_state_reward(self):
         return (
             {
                 "event": self.reward_config["event"] * self.update_max_event_rew(),
-                "seen_pokemon": self.reward_config["seen_pokemon"] * sum(self.seen_pokemon),
-                "caught_pokemon": self.reward_config["caught_pokemon"] * sum(self.caught_pokemon),
-                "moves_obtained": self.reward_config["moves_obtained"] * sum(self.moves_obtained),
+                "seen_pokemon": self.reward_config["seen_pokemon"] * np.sum(self.seen_pokemon),
+                "caught_pokemon": self.reward_config["caught_pokemon"]
+                * np.sum(self.caught_pokemon),
+                "moves_obtained": self.reward_config["moves_obtained"]
+                * np.sum(self.moves_obtained),
                 "hm_count": self.reward_config["hm_count"] * self.get_hm_count(),
                 "level": self.reward_config["level"] * self.get_levels_reward(),
                 "badges": self.reward_config["badges"] * self.get_badges(),
                 "cut_coords": self.reward_config["cut_coords"] * sum(self.cut_coords.values()),
                 "cut_tiles": self.reward_config["cut_tiles"] * sum(self.cut_tiles.values()),
                 "start_menu": self.reward_config["start_menu"] * self.seen_start_menu,
                 "pokemon_menu": self.reward_config["pokemon_menu"] * self.seen_pokemon_menu,
                 "stats_menu": self.reward_config["stats_menu"] * self.seen_stats_menu,
                 "bag_menu": self.reward_config["bag_menu"] * self.seen_bag_menu,
                 "explore_hidden_objs": sum(self.seen_hidden_objs.values()),
                 "explore_signs": sum(self.seen_signs.values())
                 * self.reward_config["explore_signs"],
                 "seen_action_bag_menu": self.seen_action_bag_menu
                 * self.reward_config["seen_action_bag_menu"],