From 467df24bfc2626512493d87a07f265cd7a73ba28 Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Tue, 25 Jun 2024 03:19:19 -0400 Subject: [PATCH] Add game corner rocket as an event --- pokemonred_puffer/data/events.py | 4 +- pokemonred_puffer/data/missable_objects.py | 249 ++++++++++++++++++ pokemonred_puffer/environment.py | 14 +- .../policies/multi_convolutional.py | 2 + pokemonred_puffer/rewards/baseline.py | 2 + 5 files changed, 268 insertions(+), 3 deletions(-) create mode 100644 pokemonred_puffer/data/missable_objects.py diff --git a/pokemonred_puffer/data/events.py b/pokemonred_puffer/data/events.py index fdf46ba..c0dd0ed 100644 --- a/pokemonred_puffer/data/events.py +++ b/pokemonred_puffer/data/events.py @@ -7,7 +7,7 @@ MUSEUM_TICKET = (0xD754, 0) -class Flags_bits(LittleEndianStructure): +class EventFlagsBits(LittleEndianStructure): _fields_ = [ ("EVENT_FOLLOWED_OAK_INTO_LAB", c_uint8, 1), ("EVENT_001", c_uint8, 1), @@ -2573,7 +2573,7 @@ class Flags_bits(LittleEndianStructure): class EventFlags(Union): - _fields_ = [("b", Flags_bits), ("asbytes", c_uint8 * 320)] + _fields_ = [("b", EventFlagsBits), ("asbytes", c_uint8 * 320)] def __init__(self, emu: PyBoy): super().__init__() diff --git a/pokemonred_puffer/data/missable_objects.py b/pokemonred_puffer/data/missable_objects.py new file mode 100644 index 0000000..6324d2b --- /dev/null +++ b/pokemonred_puffer/data/missable_objects.py @@ -0,0 +1,249 @@ +# Pad to 32 Bytes +from ctypes import LittleEndianStructure, Union, c_uint8 + +from pyboy import PyBoy + + +class MissableFlagsBits(LittleEndianStructure): + _fields_ = [ + ("HS_PALLET_TOWN_OAK", c_uint8, 1), + ("HS_LYING_OLD_MAN", c_uint8, 1), + ("HS_OLD_MAN", c_uint8, 1), + ("HS_MUSEUM_GUY", c_uint8, 1), + ("HS_GYM_GUY", c_uint8, 1), + ("HS_CERULEAN_RIVAL", c_uint8, 1), + ("HS_CERULEAN_ROCKET", c_uint8, 1), + ("HS_CERULEAN_GUARD_1", c_uint8, 1), + ("HS_CERULEAN_CAVE_GUY", c_uint8, 1), + ("HS_CERULEAN_GUARD_2", c_uint8, 1), + ("HS_SAFFRON_CITY_1", c_uint8, 1), + ("HS_SAFFRON_CITY_2", c_uint8, 1), + ("HS_SAFFRON_CITY_3", c_uint8, 1), + ("HS_SAFFRON_CITY_4", c_uint8, 1), + ("HS_SAFFRON_CITY_5", c_uint8, 1), + ("HS_SAFFRON_CITY_6", c_uint8, 1), + ("HS_SAFFRON_CITY_7", c_uint8, 1), + ("HS_SAFFRON_CITY_8", c_uint8, 1), + ("HS_SAFFRON_CITY_9", c_uint8, 1), + ("HS_SAFFRON_CITY_A", c_uint8, 1), + ("HS_SAFFRON_CITY_B", c_uint8, 1), + ("HS_SAFFRON_CITY_C", c_uint8, 1), + ("HS_SAFFRON_CITY_D", c_uint8, 1), + ("HS_SAFFRON_CITY_E", c_uint8, 1), + ("HS_SAFFRON_CITY_F", c_uint8, 1), + ("HS_ROUTE_2_ITEM_1", c_uint8, 1), + ("HS_ROUTE_2_ITEM_2", c_uint8, 1), + ("HS_ROUTE_4_ITEM", c_uint8, 1), + ("HS_ROUTE_9_ITEM", c_uint8, 1), + ("HS_ROUTE_12_SNORLAX", c_uint8, 1), + ("HS_ROUTE_12_ITEM_1", c_uint8, 1), + ("HS_ROUTE_12_ITEM_2", c_uint8, 1), + ("HS_ROUTE_15_ITEM", c_uint8, 1), + ("HS_ROUTE_16_SNORLAX", c_uint8, 1), + ("HS_ROUTE_22_RIVAL_1", c_uint8, 1), + ("HS_ROUTE_22_RIVAL_2", c_uint8, 1), + ("HS_NUGGET_BRIDGE_GUY", c_uint8, 1), + ("HS_ROUTE_24_ITEM", c_uint8, 1), + ("HS_ROUTE_25_ITEM", c_uint8, 1), + ("HS_DAISY_SITTING", c_uint8, 1), + ("HS_DAISY_WALKING", c_uint8, 1), + ("HS_TOWN_MAP", c_uint8, 1), + ("HS_OAKS_LAB_RIVAL", c_uint8, 1), + ("HS_STARTER_BALL_1", c_uint8, 1), + ("HS_STARTER_BALL_2", c_uint8, 1), + ("HS_STARTER_BALL_3", c_uint8, 1), + ("HS_OAKS_LAB_OAK_1", c_uint8, 1), + ("HS_POKEDEX_1", c_uint8, 1), + ("HS_POKEDEX_2", c_uint8, 1), + ("HS_OAKS_LAB_OAK_2", c_uint8, 1), + ("HS_VIRIDIAN_GYM_GIOVANNI", c_uint8, 1), + ("HS_VIRIDIAN_GYM_ITEM", c_uint8, 1), + ("HS_OLD_AMBER", c_uint8, 1), + ("HS_CERULEAN_CAVE_1F_ITEM_1", c_uint8, 1), + ("HS_CERULEAN_CAVE_1F_ITEM_2", c_uint8, 1), + ("HS_CERULEAN_CAVE_1F_ITEM_3", c_uint8, 1), + ("HS_POKEMON_TOWER_2F_RIVAL", c_uint8, 1), + ("HS_POKEMON_TOWER_3F_ITEM", c_uint8, 1), + ("HS_POKEMON_TOWER_4F_ITEM_1", c_uint8, 1), + ("HS_POKEMON_TOWER_4F_ITEM_2", c_uint8, 1), + ("HS_POKEMON_TOWER_4F_ITEM_3", c_uint8, 1), + ("HS_POKEMON_TOWER_5F_ITEM", c_uint8, 1), + ("HS_POKEMON_TOWER_6F_ITEM_1", c_uint8, 1), + ("HS_POKEMON_TOWER_6F_ITEM_2", c_uint8, 1), + ("HS_POKEMON_TOWER_7F_ROCKET_1", c_uint8, 1), + ("HS_POKEMON_TOWER_7F_ROCKET_2", c_uint8, 1), + ("HS_POKEMON_TOWER_7F_ROCKET_3", c_uint8, 1), + ("HS_POKEMON_TOWER_7F_MR_FUJI", c_uint8, 1), + ("HS_MR_FUJIS_HOUSE_MR_FUJI", c_uint8, 1), + ("HS_CELADON_MANSION_EEVEE_GIFT", c_uint8, 1), + ("HS_GAME_CORNER_ROCKET", c_uint8, 1), + ("HS_WARDENS_HOUSE_ITEM", c_uint8, 1), + ("HS_POKEMON_MANSION_1F_ITEM_1", c_uint8, 1), + ("HS_POKEMON_MANSION_1F_ITEM_2", c_uint8, 1), + ("HS_FIGHTING_DOJO_GIFT_1", c_uint8, 1), + ("HS_FIGHTING_DOJO_GIFT_2", c_uint8, 1), + ("HS_SILPH_CO_1F_RECEPTIONIST", c_uint8, 1), + ("HS_VOLTORB_1", c_uint8, 1), + ("HS_VOLTORB_2", c_uint8, 1), + ("HS_VOLTORB_3", c_uint8, 1), + ("HS_ELECTRODE_1", c_uint8, 1), + ("HS_VOLTORB_4", c_uint8, 1), + ("HS_VOLTORB_5", c_uint8, 1), + ("HS_ELECTRODE_2", c_uint8, 1), + ("HS_VOLTORB_6", c_uint8, 1), + ("HS_ZAPDOS", c_uint8, 1), + ("HS_POWER_PLANT_ITEM_1", c_uint8, 1), + ("HS_POWER_PLANT_ITEM_2", c_uint8, 1), + ("HS_POWER_PLANT_ITEM_3", c_uint8, 1), + ("HS_POWER_PLANT_ITEM_4", c_uint8, 1), + ("HS_POWER_PLANT_ITEM_5", c_uint8, 1), + ("HS_MOLTRES", c_uint8, 1), + ("HS_VICTORY_ROAD_2F_ITEM_1", c_uint8, 1), + ("HS_VICTORY_ROAD_2F_ITEM_2", c_uint8, 1), + ("HS_VICTORY_ROAD_2F_ITEM_3", c_uint8, 1), + ("HS_VICTORY_ROAD_2F_ITEM_4", c_uint8, 1), + ("HS_VICTORY_ROAD_2F_BOULDER", c_uint8, 1), + ("HS_BILL_POKEMON", c_uint8, 1), + ("HS_BILL_1", c_uint8, 1), + ("HS_BILL_2", c_uint8, 1), + ("HS_VIRIDIAN_FOREST_ITEM_1", c_uint8, 1), + ("HS_VIRIDIAN_FOREST_ITEM_2", c_uint8, 1), + ("HS_VIRIDIAN_FOREST_ITEM_3", c_uint8, 1), + ("HS_MT_MOON_1F_ITEM_1", c_uint8, 1), + ("HS_MT_MOON_1F_ITEM_2", c_uint8, 1), + ("HS_MT_MOON_1F_ITEM_3", c_uint8, 1), + ("HS_MT_MOON_1F_ITEM_4", c_uint8, 1), + ("HS_MT_MOON_1F_ITEM_5", c_uint8, 1), + ("HS_MT_MOON_1F_ITEM_6", c_uint8, 1), + ("HS_MT_MOON_B2F_FOSSIL_1", c_uint8, 1), + ("HS_MT_MOON_B2F_FOSSIL_2", c_uint8, 1), + ("HS_MT_MOON_B2F_ITEM_1", c_uint8, 1), + ("HS_MT_MOON_B2F_ITEM_2", c_uint8, 1), + ("HS_SS_ANNE_2F_RIVAL", c_uint8, 1), + ("HS_SS_ANNE_1F_ROOMS_ITEM", c_uint8, 1), + ("HS_SS_ANNE_2F_ROOMS_ITEM_1", c_uint8, 1), + ("HS_SS_ANNE_2F_ROOMS_ITEM_2", c_uint8, 1), + ("HS_SS_ANNE_B1F_ROOMS_ITEM_1", c_uint8, 1), + ("HS_SS_ANNE_B1F_ROOMS_ITEM_2", c_uint8, 1), + ("HS_SS_ANNE_B1F_ROOMS_ITEM_3", c_uint8, 1), + ("HS_VICTORY_ROAD_3F_ITEM_1", c_uint8, 1), + ("HS_VICTORY_ROAD_3F_ITEM_2", c_uint8, 1), + ("HS_VICTORY_ROAD_3F_BOULDER", c_uint8, 1), + ("HS_ROCKET_HIDEOUT_B1F_ITEM_1", c_uint8, 1), + ("HS_ROCKET_HIDEOUT_B1F_ITEM_2", c_uint8, 1), + ("HS_ROCKET_HIDEOUT_B2F_ITEM_1", c_uint8, 1), + ("HS_ROCKET_HIDEOUT_B2F_ITEM_2", c_uint8, 1), + ("HS_ROCKET_HIDEOUT_B2F_ITEM_3", c_uint8, 1), + ("HS_ROCKET_HIDEOUT_B2F_ITEM_4", c_uint8, 1), + ("HS_ROCKET_HIDEOUT_B3F_ITEM_1", c_uint8, 1), + ("HS_ROCKET_HIDEOUT_B3F_ITEM_2", c_uint8, 1), + ("HS_ROCKET_HIDEOUT_B4F_GIOVANNI", c_uint8, 1), + ("HS_ROCKET_HIDEOUT_B4F_ITEM_1", c_uint8, 1), + ("HS_ROCKET_HIDEOUT_B4F_ITEM_2", c_uint8, 1), + ("HS_ROCKET_HIDEOUT_B4F_ITEM_3", c_uint8, 1), + ("HS_ROCKET_HIDEOUT_B4F_ITEM_4", c_uint8, 1), + ("HS_ROCKET_HIDEOUT_B4F_ITEM_5", c_uint8, 1), + ("HS_SILPH_CO_2F_1", c_uint8, 1), + ("HS_SILPH_CO_2F_2", c_uint8, 1), + ("HS_SILPH_CO_2F_3", c_uint8, 1), + ("HS_SILPH_CO_2F_4", c_uint8, 1), + ("HS_SILPH_CO_2F_5", c_uint8, 1), + ("HS_SILPH_CO_3F_1", c_uint8, 1), + ("HS_SILPH_CO_3F_2", c_uint8, 1), + ("HS_SILPH_CO_3F_ITEM", c_uint8, 1), + ("HS_SILPH_CO_4F_1", c_uint8, 1), + ("HS_SILPH_CO_4F_2", c_uint8, 1), + ("HS_SILPH_CO_4F_3", c_uint8, 1), + ("HS_SILPH_CO_4F_ITEM_1", c_uint8, 1), + ("HS_SILPH_CO_4F_ITEM_2", c_uint8, 1), + ("HS_SILPH_CO_4F_ITEM_3", c_uint8, 1), + ("HS_SILPH_CO_5F_1", c_uint8, 1), + ("HS_SILPH_CO_5F_2", c_uint8, 1), + ("HS_SILPH_CO_5F_3", c_uint8, 1), + ("HS_SILPH_CO_5F_4", c_uint8, 1), + ("HS_SILPH_CO_5F_ITEM_1", c_uint8, 1), + ("HS_SILPH_CO_5F_ITEM_2", c_uint8, 1), + ("HS_SILPH_CO_5F_ITEM_3", c_uint8, 1), + ("HS_SILPH_CO_6F_1", c_uint8, 1), + ("HS_SILPH_CO_6F_2", c_uint8, 1), + ("HS_SILPH_CO_6F_3", c_uint8, 1), + ("HS_SILPH_CO_6F_ITEM_1", c_uint8, 1), + ("HS_SILPH_CO_6F_ITEM_2", c_uint8, 1), + ("HS_SILPH_CO_7F_1", c_uint8, 1), + ("HS_SILPH_CO_7F_2", c_uint8, 1), + ("HS_SILPH_CO_7F_3", c_uint8, 1), + ("HS_SILPH_CO_7F_4", c_uint8, 1), + ("HS_SILPH_CO_7F_RIVAL", c_uint8, 1), + ("HS_SILPH_CO_7F_ITEM_1", c_uint8, 1), + ("HS_SILPH_CO_7F_ITEM_2", c_uint8, 1), + ("HS_SILPH_CO_7F_8", c_uint8, 1), + ("HS_SILPH_CO_8F_1", c_uint8, 1), + ("HS_SILPH_CO_8F_2", c_uint8, 1), + ("HS_SILPH_CO_8F_3", c_uint8, 1), + ("HS_SILPH_CO_9F_1", c_uint8, 1), + ("HS_SILPH_CO_9F_2", c_uint8, 1), + ("HS_SILPH_CO_9F_3", c_uint8, 1), + ("HS_SILPH_CO_10F_1", c_uint8, 1), + ("HS_SILPH_CO_10F_2", c_uint8, 1), + ("HS_SILPH_CO_10F_3", c_uint8, 1), + ("HS_SILPH_CO_10F_ITEM_1", c_uint8, 1), + ("HS_SILPH_CO_10F_ITEM_2", c_uint8, 1), + ("HS_SILPH_CO_10F_ITEM_3", c_uint8, 1), + ("HS_SILPH_CO_11F_1", c_uint8, 1), + ("HS_SILPH_CO_11F_2", c_uint8, 1), + ("HS_SILPH_CO_11F_3", c_uint8, 1), + ("HS_UNUSED_MAP_F4_1", c_uint8, 1), + ("HS_POKEMON_MANSION_2F_ITEM", c_uint8, 1), + ("HS_POKEMON_MANSION_3F_ITEM_1", c_uint8, 1), + ("HS_POKEMON_MANSION_3F_ITEM_2", c_uint8, 1), + ("HS_POKEMON_MANSION_B1F_ITEM_1", c_uint8, 1), + ("HS_POKEMON_MANSION_B1F_ITEM_2", c_uint8, 1), + ("HS_POKEMON_MANSION_B1F_ITEM_3", c_uint8, 1), + ("HS_POKEMON_MANSION_B1F_ITEM_4", c_uint8, 1), + ("HS_POKEMON_MANSION_B1F_ITEM_5", c_uint8, 1), + ("HS_SAFARI_ZONE_EAST_ITEM_1", c_uint8, 1), + ("HS_SAFARI_ZONE_EAST_ITEM_2", c_uint8, 1), + ("HS_SAFARI_ZONE_EAST_ITEM_3", c_uint8, 1), + ("HS_SAFARI_ZONE_EAST_ITEM_4", c_uint8, 1), + ("HS_SAFARI_ZONE_NORTH_ITEM_1", c_uint8, 1), + ("HS_SAFARI_ZONE_NORTH_ITEM_2", c_uint8, 1), + ("HS_SAFARI_ZONE_WEST_ITEM_1", c_uint8, 1), + ("HS_SAFARI_ZONE_WEST_ITEM_2", c_uint8, 1), + ("HS_SAFARI_ZONE_WEST_ITEM_3", c_uint8, 1), + ("HS_SAFARI_ZONE_WEST_ITEM_4", c_uint8, 1), + ("HS_SAFARI_ZONE_CENTER_ITEM", c_uint8, 1), + ("HS_CERULEAN_CAVE_2F_ITEM_1", c_uint8, 1), + ("HS_CERULEAN_CAVE_2F_ITEM_2", c_uint8, 1), + ("HS_CERULEAN_CAVE_2F_ITEM_3", c_uint8, 1), + ("HS_MEWTWO", c_uint8, 1), + ("HS_CERULEAN_CAVE_B1F_ITEM_1", c_uint8, 1), + ("HS_CERULEAN_CAVE_B1F_ITEM_2", c_uint8, 1), + ("HS_VICTORY_ROAD_1F_ITEM_1", c_uint8, 1), + ("HS_VICTORY_ROAD_1F_ITEM_2", c_uint8, 1), + ("HS_CHAMPIONS_ROOM_OAK", c_uint8, 1), + ("HS_SEAFOAM_ISLANDS_1F_BOULDER_1", c_uint8, 1), + ("HS_SEAFOAM_ISLANDS_1F_BOULDER_2", c_uint8, 1), + ("HS_SEAFOAM_ISLANDS_B1F_BOULDER_1", c_uint8, 1), + ("HS_SEAFOAM_ISLANDS_B1F_BOULDER_2", c_uint8, 1), + ("HS_SEAFOAM_ISLANDS_B2F_BOULDER_1", c_uint8, 1), + ("HS_SEAFOAM_ISLANDS_B2F_BOULDER_2", c_uint8, 1), + ("HS_SEAFOAM_ISLANDS_B3F_BOULDER_1", c_uint8, 1), + ("HS_SEAFOAM_ISLANDS_B3F_BOULDER_2", c_uint8, 1), + ("HS_SEAFOAM_ISLANDS_B3F_BOULDER_3", c_uint8, 1), + ("HS_SEAFOAM_ISLANDS_B3F_BOULDER_4", c_uint8, 1), + ("HS_SEAFOAM_ISLANDS_B4F_BOULDER_1", c_uint8, 1), + ("HS_SEAFOAM_ISLANDS_B4F_BOULDER_2", c_uint8, 1), + ("HS_ARTICUNO", c_uint8, 1), + ] + + +class MissableFlags(Union): + # These missable flags is a 32 byte object + _fields_ = [("b", MissableFlagsBits), ("asbytes", c_uint8 * 32)] + + def __init__(self, emu: PyBoy): + super().__init__() + self.asbytes = (c_uint8 * 32)(*emu.memory[0xD5A6 : 0xD5A6 + 32]) + + def get_missable(self, missable: str) -> bool: + return bool(getattr(self.b, missable)) diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 4c2a38a..fb43ac5 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -32,6 +32,7 @@ USEFUL_ITEMS, Items, ) +from pokemonred_puffer.data.missable_objects import MissableFlags from pokemonred_puffer.data.strength_puzzles import STRENGTH_SOLUTIONS from pokemonred_puffer.data.tilesets import Tilesets from pokemonred_puffer.data.tm_hm import ( @@ -166,6 +167,8 @@ def __init__(self, env_config: pufferlib.namespace): low=0, high=max(Items._value2member_map_.keys()), shape=(20,), dtype=np.uint8 ), "bag_quantity": spaces.Box(low=0, high=100, shape=(20,), dtype=np.uint8), + "rival_3": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), + "game_corner_rocket": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), } | { event: spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8) for event in REQUIRED_EVENTS @@ -290,6 +293,7 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = self.reset_mem() self.events = EventFlags(self.pyboy) + self.missables = MissableFlags(self.pyboy) self.update_pokedex() self.update_tm_hm_moves_obtained() self.taught_cut = self.check_if_party_has_hm(0xF) @@ -515,6 +519,10 @@ def _get_obs(self): "wJoyIgnore": np.array(self.read_m("wJoyIgnore"), dtype=np.uint8), "bag_items": bag[::2].copy(), "bag_quantity": bag[1::2].copy(), + "rival_3": np.array(self.read_m("wSSAnne2FCurScript") == 4, dtype=np.uint8), + "game_corner_rocket": np.array( + self.missables.get_missable("HS_GAME_CORNER_ROCKET"), dtype=np.uint8 + ), } | {event: np.array(self.events.get_event(event)) for event in REQUIRED_EVENTS} ) @@ -554,6 +562,7 @@ def step(self, action): self.run_action_on_emulator(action) self.events = EventFlags(self.pyboy) + self.missables = MissableFlags(self.pyboy) self.update_seen_coords() self.update_health() self.update_pokedex() @@ -1127,7 +1136,10 @@ def agent_stats(self, action): } | {f"badge_{i+1}": bool(badges & (1 << i)) for i in range(8)}, "events": {event: self.events.get_event(event) for event in REQUIRED_EVENTS} - | {"rival3": int(self.read_m(0xD665) == 4)}, + | { + "rival3": int(self.read_m(0xD665) == 4), + "game_corner_rocket": self.missables.get_missable("HS_GAME_CORNER_ROCKET"), + }, "required_items": {item.name: item.value in bag_item_ids for item in REQUIRED_ITEMS}, "useful_items": {item.name: item.value in bag_item_ids for item in USEFUL_ITEMS}, "reward": self.get_game_state_reward(), diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py index 706c9b9..9eb614f 100644 --- a/pokemonred_puffer/policies/multi_convolutional.py +++ b/pokemonred_puffer/policies/multi_convolutional.py @@ -181,6 +181,8 @@ def encode_observations(self, observations): blackout_map_id.squeeze(1), observations["wJoyIgnore"].float(), items.flatten(start_dim=1), + observations["rival_3"].float(), + observations["game_corner_rocket"].float(), ) + tuple(observations[event].float() for event in REQUIRED_EVENTS), dim=-1, diff --git a/pokemonred_puffer/rewards/baseline.py b/pokemonred_puffer/rewards/baseline.py index 029f9c0..6b58c9c 100644 --- a/pokemonred_puffer/rewards/baseline.py +++ b/pokemonred_puffer/rewards/baseline.py @@ -258,6 +258,8 @@ def get_game_state_reward(self): "seen_action_bag_menu": self.seen_action_bag_menu * self.reward_config["seen_action_bag_menu"], "rival3": self.reward_config["event"] * int(self.read_m("wSSAnne2FCurScript") == 4), + "game_corner_rocket": self.reward_config["event"] + * float(self.missables.get_missable("HS_GAME_CORNER_ROCKET")), } | { event: self.reward_config["required_event"] * float(self.events.get_event(event))