diff --git a/pokemonred_puffer/data/map.py b/pokemonred_puffer/data/map.py index 28bb5ca..0257adb 100644 --- a/pokemonred_puffer/data/map.py +++ b/pokemonred_puffer/data/map.py @@ -287,17 +287,17 @@ class MapIds(Enum): MapIds.ROCKET_HIDEOUT_B3F: [[], ["EVENT_BEAT_ROCKET_HIDEOUT_GIOVANNI"]], MapIds.ROCKET_HIDEOUT_B4F: [[], ["EVENT_BEAT_ROCKET_HIDEOUT_GIOVANNI"]], MapIds.ROCKET_HIDEOUT_ELEVATOR: [[], ["EVENT_BEAT_ROCKET_HIDEOUT_GIOVANNI"]], - MapIds.SILPH_CO_1F: [[], ["EVENT_BEAT_SILPH_CO_GIOVANNI"]], - MapIds.SILPH_CO_2F: [[], ["EVENT_BEAT_SILPH_CO_GIOVANNI"]], - MapIds.SILPH_CO_3F: [[], ["EVENT_BEAT_SILPH_CO_GIOVANNI"]], - MapIds.SILPH_CO_4F: [[], ["EVENT_BEAT_SILPH_CO_GIOVANNI"]], - MapIds.SILPH_CO_5F: [[], ["EVENT_BEAT_SILPH_CO_GIOVANNI"]], - MapIds.SILPH_CO_6F: [[], ["EVENT_BEAT_SILPH_CO_GIOVANNI"]], - MapIds.SILPH_CO_7F: [[], ["EVENT_BEAT_SILPH_CO_GIOVANNI"]], - MapIds.SILPH_CO_8F: [[], ["EVENT_BEAT_SILPH_CO_GIOVANNI"]], - MapIds.SILPH_CO_9F: [[], ["EVENT_BEAT_SILPH_CO_GIOVANNI"]], - MapIds.SILPH_CO_10F: [[], ["EVENT_BEAT_SILPH_CO_GIOVANNI"]], - MapIds.SILPH_CO_11F: [[], ["EVENT_BEAT_SILPH_CO_GIOVANNI"]], + MapIds.SILPH_CO_1F: [[], ["BIT_GOT_LAPRAS", "EVENT_BEAT_SILPH_CO_GIOVANNI"]], + MapIds.SILPH_CO_2F: [[], ["BIT_GOT_LAPRAS", "EVENT_BEAT_SILPH_CO_GIOVANNI"]], + MapIds.SILPH_CO_3F: [[], ["BIT_GOT_LAPRAS", "EVENT_BEAT_SILPH_CO_GIOVANNI"]], + MapIds.SILPH_CO_4F: [[], ["BIT_GOT_LAPRAS", "EVENT_BEAT_SILPH_CO_GIOVANNI"]], + MapIds.SILPH_CO_5F: [[], ["BIT_GOT_LAPRAS", "EVENT_BEAT_SILPH_CO_GIOVANNI"]], + MapIds.SILPH_CO_6F: [[], ["BIT_GOT_LAPRAS", "EVENT_BEAT_SILPH_CO_GIOVANNI"]], + MapIds.SILPH_CO_7F: [[], ["BIT_GOT_LAPRAS", "EVENT_BEAT_SILPH_CO_GIOVANNI"]], + MapIds.SILPH_CO_8F: [[], ["BIT_GOT_LAPRAS", "EVENT_BEAT_SILPH_CO_GIOVANNI"]], + MapIds.SILPH_CO_9F: [[], ["BIT_GOT_LAPRAS", "EVENT_BEAT_SILPH_CO_GIOVANNI"]], + MapIds.SILPH_CO_10F: [[], ["BIT_GOT_LAPRAS", "EVENT_BEAT_SILPH_CO_GIOVANNI"]], + MapIds.SILPH_CO_11F: [[], ["BIT_GOT_LAPRAS", "EVENT_BEAT_SILPH_CO_GIOVANNI"]], MapIds.SILPH_CO_ELEVATOR: [[], ["EVENT_BEAT_SILPH_CO_GIOVANNI"]], MapIds.POKEMON_MANSION_1F: [[], ["HS_POKEMON_MANSION_B1F_ITEM_5"]], MapIds.POKEMON_MANSION_2F: [[], ["HS_POKEMON_MANSION_B1F_ITEM_5"]], diff --git a/pokemonred_puffer/data/status_flags.py b/pokemonred_puffer/data/status_flags.py deleted file mode 100644 index 7e5fa42..0000000 --- a/pokemonred_puffer/data/status_flags.py +++ /dev/null @@ -1,27 +0,0 @@ -from ctypes import LittleEndianStructure, Union, c_uint8 - -from pyboy import PyBoy - - -class StatusFlags1Bits(LittleEndianStructure): - _fields_ = [ - ("USING_STRENGTH_OUTSIDE_OF_BATTLE", c_uint8, 1), - ("IS_SURFING_ALLOWED", c_uint8, 1), - ("UNUSED_0", c_uint8, 1), - ("RECEIVED_OLD_ROD", c_uint8, 1), - ("RECEIVED_GOOD_ROD", c_uint8, 1), - ("RECEIVED_SUPER_ROD", c_uint8, 1), - ("GAVE_SAFFRON_GUARD_DRINK", c_uint8, 1), - ("UNUSED_2", c_uint8, 1), - ] - - -class StatusFlags1(Union): - _fields_ = [("b", StatusFlags1Bits), ("asbytes", c_uint8)] - - def __init__(self, emu: PyBoy): - super().__init__() - self.asbytes = (c_uint8)(emu.memory[emu.symbol_lookup("wStatusFlags1")[1]]) - - def get_bit(self, bit: str) -> bool: - return bool(getattr(self.b, bit)) diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index bb5f9d9..c83e4bd 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -48,7 +48,7 @@ SURF_SPECIES_IDS, TmHmMoves, ) -from pokemonred_puffer.data.status_flags import StatusFlags1 +from pokemonred_puffer.data.flags import Flags from pokemonred_puffer.global_map import GLOBAL_MAP_SHAPE, local_to_global PIXEL_VALUES = np.array([0, 85, 153, 255], dtype=np.uint8) @@ -202,8 +202,8 @@ def __init__(self, env_config: pufferlib.namespace): "speed": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint32), "special": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint32), "moves": spaces.Box(low=0, high=0xA4, shape=(6, 4), dtype=np.uint8), - # Add 3 for rival_3, game corner rocket and saffron guard - "events": spaces.Box(low=0, high=1, shape=(len(EVENTS) + 3,), dtype=np.uint8), + # Add 4 for rival_3, game corner rocket, saffron guard and lapras + "events": spaces.Box(low=0, high=1, shape=(len(EVENTS) + 4,), dtype=np.uint8), } if not self.skip_safari_zone: obs_dict["safari_steps"] = spaces.Box(low=0, high=502.0, shape=(1,), dtype=np.uint32) @@ -313,7 +313,7 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = # A bit of duplicate code. Blah. self.events = EventFlags(self.pyboy) self.missables = MissableFlags(self.pyboy) - self.status_flags_1 = StatusFlags1(self.pyboy) + self.flags = Flags(self.pyboy) self.party = PartyMons(self.pyboy) self.required_events = self.get_required_events() self.required_items = self.get_required_items() @@ -355,7 +355,7 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = self.events = EventFlags(self.pyboy) self.missables = MissableFlags(self.pyboy) - self.status_flags_1 = StatusFlags1(self.pyboy) + self.flags = Flags(self.pyboy) self.party = PartyMons(self.pyboy) self.update_pokedex() self.update_tm_hm_moves_obtained() @@ -621,7 +621,8 @@ def _get_obs(self): + [ self.read_m("wSSAnne2FCurScript") == 4, # rival 3 self.missables.get_missable("HS_GAME_CORNER_ROCKET"), # game corner rocket - self.status_flags_1.get_bit("GAVE_SAFFRON_GUARD_DRINK"), # saffron guard + self.flags.get_bit("BIT_GAVE_SAFFRON_GUARDS_DRINK"), # saffron guard + self.flags.get_bit("BIT_GOT_LAPRAS"), # got lapras ], dtype=np.uint8, ), @@ -679,7 +680,7 @@ def step(self, action): self.run_action_on_emulator(action) self.events = EventFlags(self.pyboy) self.missables = MissableFlags(self.pyboy) - self.status_flags_1 = StatusFlags1(self.pyboy) + self.flags = Flags(self.pyboy) self.party = PartyMons(self.pyboy) self.update_health() self.update_pokedex() @@ -1121,8 +1122,7 @@ def solve_strength_puzzle(self): if not self.disable_wild_encounters: self.setup_disable_wild_encounters() # Activate strength - _, status_flags_1 = self.pyboy.symbol_lookup("wStatusFlags1") - self.pyboy.memory[status_flags_1] |= 0b0000_0001 + self.flags.set_bit("BIT_STRENGTH_ACTIVE", 1) # Perform solution current_repel_steps = self.read_m("wRepelRemainingSteps") for step in steps: @@ -1218,7 +1218,7 @@ def next_elevator_floor(self): self.pyboy.tick(self.action_freq, render=self.animate_scripts) def insert_guard_drinks(self): - if not self.status_flags_1.get_bit("GAVE_SAFFRON_GUARD_DRINK") and MapIds( + if not self.flags.get_bit("BIT_GAVE_SAFFRON_GUARDS_DRINK") and MapIds( self.read_m("wCurMap") ) in [ MapIds.CELADON_MART_1F, @@ -1415,7 +1415,8 @@ def agent_stats(self, action): | { "rival3": int(self.read_m(0xD665) == 4), "game_corner_rocket": self.missables.get_missable("HS_GAME_CORNER_ROCKET"), - "saffron_guard": self.status_flags_1.get_bit("GAVE_SAFFRON_GUARD_DRINK"), + "saffron_guard": self.flags.get_bit("BIT_GAVE_SAFFRON_GUARDS_DRINK"), + "lapras": self.flags.get_bit("BIT_GOT_LAPRAS"), }, "required_items": {item.name: item.value in bag_item_ids for item in REQUIRED_ITEMS}, "useful_items": {item.name: item.value in bag_item_ids for item in USEFUL_ITEMS}, @@ -1696,11 +1697,8 @@ def get_required_events(self) -> set[str]: if self.missables.get_missable("HS_GAME_CORNER_ROCKET") else set() ) - | ( - {"saffron_guard"} - if self.status_flags_1.get_bit("GAVE_SAFFRON_GUARD_DRINK") - else set() - ) + | ({"saffron_guard"} if self.flags.get_bit("BIT_GAVE_SAFFRON_GUARDS_DRINK") else set()) + | ({"lapras"} if self.flags.get_bit("BIT_GOT_LAPRAS") else set()) ) def get_required_items(self) -> set[str]: @@ -1727,25 +1725,18 @@ def scale_map_id(self, map_n: int) -> float: map_id = MapIds(map_n) if map_id not in MAP_ID_COMPLETION_EVENTS: return False - after_events, until_events = MAP_ID_COMPLETION_EVENTS[map_id] + after, until = MAP_ID_COMPLETION_EVENTS[map_id] if all( - (event_or_missable.startswith("EVENT_") and self.events.get_event(event_or_missable)) - or ( - event_or_missable.startswith("HS_") - and self.missables.get_missable(event_or_missable) - ) - for event_or_missable in after_events + (item.startswith("EVENT_") and self.events.get_event(item)) + or (item.startswith("HS_") and self.missables.get_missable(item)) + or (item.startswith("BIT_") and self.flags.get_bit(item)) + for item in after ) and all( - ( - event_or_missable.startswith("EVENT_") - and not self.events.get_event(event_or_missable) - ) - or ( - event_or_missable.startswith("HS_") - and not self.missables.get_missable(event_or_missable) - ) - for event_or_missable in until_events + (item.startswith("EVENT_") and not self.events.get_event(item)) + or (item.startswith("HS_") and not self.missables.get_missable(item)) + or (item.startswith("BIT_") and not self.flags.get_bit(item)) + for item in until ): return True return False diff --git a/pokemonred_puffer/rewards/baseline.py b/pokemonred_puffer/rewards/baseline.py index a848004..6dad126 100644 --- a/pokemonred_puffer/rewards/baseline.py +++ b/pokemonred_puffer/rewards/baseline.py @@ -269,7 +269,9 @@ def get_game_state_reward(self): "game_corner_rocket": self.reward_config["required_event"] * float(self.missables.get_missable("HS_GAME_CORNER_ROCKET")), "saffron_guard": self.reward_config["required_event"] - * float(self.status_flags_1.get_bit("GAVE_SAFFRON_GUARD_DRINK")), + * float(self.flags.get_bit("BIT_GAVE_SAFFRON_GUARDS_DRINK")), + "lapras": self.reward_config["required_event"] + * float(self.flags.get_bit("BIT_GOT_LAPRAS")), } | { event: self.reward_config["required_event"] * float(self.events.get_event(event)) @@ -327,7 +329,9 @@ def get_game_state_reward(self): "game_corner_rocket": self.reward_config["required_event"] * float(self.missables.get_missable("HS_GAME_CORNER_ROCKET")), "saffron_guard": self.reward_config["required_event"] - * float(self.status_flags_1.get_bit("GAVE_SAFFRON_GUARD_DRINK")), + * float(self.flags.get_bit("BIT_GAVE_SAFFRON_GUARDS_DRINK")), + "lapras": self.reward_config["required_event"] + * float(self.flags.get_bit("BIT_GOT_LAPRAS")), "a_press": len(self.a_press) * self.reward_config["a_press"], "warps": len(self.seen_warps) * self.reward_config["explore_warps"], "use_surf": self.reward_config["use_surf"] * self.use_surf, @@ -395,7 +399,9 @@ def get_game_state_reward(self): "game_corner_rocket": self.reward_config["required_event"] * float(self.missables.get_missable("HS_GAME_CORNER_ROCKET")), "saffron_guard": self.reward_config["required_event"] - * float(self.status_flags_1.get_bit("GAVE_SAFFRON_GUARD_DRINK")), + * float(self.flags.get_bit("BIT_GAVE_SAFFRON_GUARDS_DRINK")), + "lapras": self.reward_config["required_event"] + * float(self.flags.get_bit("BIT_GOT_LAPRAS")), "a_press": len(self.a_press) * self.reward_config["a_press"], "warps": len(self.seen_warps) * self.reward_config["explore_warps"], "use_surf": self.reward_config["use_surf"] * self.use_surf,