Skip to content

Commit

Permalink
Add game corner rocket as an event
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Jun 25, 2024
1 parent 1b756f5 commit 467df24
Show file tree
Hide file tree
Showing 5 changed files with 268 additions and 3 deletions.
4 changes: 2 additions & 2 deletions pokemonred_puffer/data/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
MUSEUM_TICKET = (0xD754, 0)


class Flags_bits(LittleEndianStructure):
class EventFlagsBits(LittleEndianStructure):
_fields_ = [
("EVENT_FOLLOWED_OAK_INTO_LAB", c_uint8, 1),
("EVENT_001", c_uint8, 1),
Expand Down Expand Up @@ -2573,7 +2573,7 @@ class Flags_bits(LittleEndianStructure):


class EventFlags(Union):
_fields_ = [("b", Flags_bits), ("asbytes", c_uint8 * 320)]
_fields_ = [("b", EventFlagsBits), ("asbytes", c_uint8 * 320)]

def __init__(self, emu: PyBoy):
super().__init__()
Expand Down
249 changes: 249 additions & 0 deletions pokemonred_puffer/data/missable_objects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
# Pad to 32 Bytes
from ctypes import LittleEndianStructure, Union, c_uint8

from pyboy import PyBoy


class MissableFlagsBits(LittleEndianStructure):
_fields_ = [
("HS_PALLET_TOWN_OAK", c_uint8, 1),
("HS_LYING_OLD_MAN", c_uint8, 1),
("HS_OLD_MAN", c_uint8, 1),
("HS_MUSEUM_GUY", c_uint8, 1),
("HS_GYM_GUY", c_uint8, 1),
("HS_CERULEAN_RIVAL", c_uint8, 1),
("HS_CERULEAN_ROCKET", c_uint8, 1),
("HS_CERULEAN_GUARD_1", c_uint8, 1),
("HS_CERULEAN_CAVE_GUY", c_uint8, 1),
("HS_CERULEAN_GUARD_2", c_uint8, 1),
("HS_SAFFRON_CITY_1", c_uint8, 1),
("HS_SAFFRON_CITY_2", c_uint8, 1),
("HS_SAFFRON_CITY_3", c_uint8, 1),
("HS_SAFFRON_CITY_4", c_uint8, 1),
("HS_SAFFRON_CITY_5", c_uint8, 1),
("HS_SAFFRON_CITY_6", c_uint8, 1),
("HS_SAFFRON_CITY_7", c_uint8, 1),
("HS_SAFFRON_CITY_8", c_uint8, 1),
("HS_SAFFRON_CITY_9", c_uint8, 1),
("HS_SAFFRON_CITY_A", c_uint8, 1),
("HS_SAFFRON_CITY_B", c_uint8, 1),
("HS_SAFFRON_CITY_C", c_uint8, 1),
("HS_SAFFRON_CITY_D", c_uint8, 1),
("HS_SAFFRON_CITY_E", c_uint8, 1),
("HS_SAFFRON_CITY_F", c_uint8, 1),
("HS_ROUTE_2_ITEM_1", c_uint8, 1),
("HS_ROUTE_2_ITEM_2", c_uint8, 1),
("HS_ROUTE_4_ITEM", c_uint8, 1),
("HS_ROUTE_9_ITEM", c_uint8, 1),
("HS_ROUTE_12_SNORLAX", c_uint8, 1),
("HS_ROUTE_12_ITEM_1", c_uint8, 1),
("HS_ROUTE_12_ITEM_2", c_uint8, 1),
("HS_ROUTE_15_ITEM", c_uint8, 1),
("HS_ROUTE_16_SNORLAX", c_uint8, 1),
("HS_ROUTE_22_RIVAL_1", c_uint8, 1),
("HS_ROUTE_22_RIVAL_2", c_uint8, 1),
("HS_NUGGET_BRIDGE_GUY", c_uint8, 1),
("HS_ROUTE_24_ITEM", c_uint8, 1),
("HS_ROUTE_25_ITEM", c_uint8, 1),
("HS_DAISY_SITTING", c_uint8, 1),
("HS_DAISY_WALKING", c_uint8, 1),
("HS_TOWN_MAP", c_uint8, 1),
("HS_OAKS_LAB_RIVAL", c_uint8, 1),
("HS_STARTER_BALL_1", c_uint8, 1),
("HS_STARTER_BALL_2", c_uint8, 1),
("HS_STARTER_BALL_3", c_uint8, 1),
("HS_OAKS_LAB_OAK_1", c_uint8, 1),
("HS_POKEDEX_1", c_uint8, 1),
("HS_POKEDEX_2", c_uint8, 1),
("HS_OAKS_LAB_OAK_2", c_uint8, 1),
("HS_VIRIDIAN_GYM_GIOVANNI", c_uint8, 1),
("HS_VIRIDIAN_GYM_ITEM", c_uint8, 1),
("HS_OLD_AMBER", c_uint8, 1),
("HS_CERULEAN_CAVE_1F_ITEM_1", c_uint8, 1),
("HS_CERULEAN_CAVE_1F_ITEM_2", c_uint8, 1),
("HS_CERULEAN_CAVE_1F_ITEM_3", c_uint8, 1),
("HS_POKEMON_TOWER_2F_RIVAL", c_uint8, 1),
("HS_POKEMON_TOWER_3F_ITEM", c_uint8, 1),
("HS_POKEMON_TOWER_4F_ITEM_1", c_uint8, 1),
("HS_POKEMON_TOWER_4F_ITEM_2", c_uint8, 1),
("HS_POKEMON_TOWER_4F_ITEM_3", c_uint8, 1),
("HS_POKEMON_TOWER_5F_ITEM", c_uint8, 1),
("HS_POKEMON_TOWER_6F_ITEM_1", c_uint8, 1),
("HS_POKEMON_TOWER_6F_ITEM_2", c_uint8, 1),
("HS_POKEMON_TOWER_7F_ROCKET_1", c_uint8, 1),
("HS_POKEMON_TOWER_7F_ROCKET_2", c_uint8, 1),
("HS_POKEMON_TOWER_7F_ROCKET_3", c_uint8, 1),
("HS_POKEMON_TOWER_7F_MR_FUJI", c_uint8, 1),
("HS_MR_FUJIS_HOUSE_MR_FUJI", c_uint8, 1),
("HS_CELADON_MANSION_EEVEE_GIFT", c_uint8, 1),
("HS_GAME_CORNER_ROCKET", c_uint8, 1),
("HS_WARDENS_HOUSE_ITEM", c_uint8, 1),
("HS_POKEMON_MANSION_1F_ITEM_1", c_uint8, 1),
("HS_POKEMON_MANSION_1F_ITEM_2", c_uint8, 1),
("HS_FIGHTING_DOJO_GIFT_1", c_uint8, 1),
("HS_FIGHTING_DOJO_GIFT_2", c_uint8, 1),
("HS_SILPH_CO_1F_RECEPTIONIST", c_uint8, 1),
("HS_VOLTORB_1", c_uint8, 1),
("HS_VOLTORB_2", c_uint8, 1),
("HS_VOLTORB_3", c_uint8, 1),
("HS_ELECTRODE_1", c_uint8, 1),
("HS_VOLTORB_4", c_uint8, 1),
("HS_VOLTORB_5", c_uint8, 1),
("HS_ELECTRODE_2", c_uint8, 1),
("HS_VOLTORB_6", c_uint8, 1),
("HS_ZAPDOS", c_uint8, 1),
("HS_POWER_PLANT_ITEM_1", c_uint8, 1),
("HS_POWER_PLANT_ITEM_2", c_uint8, 1),
("HS_POWER_PLANT_ITEM_3", c_uint8, 1),
("HS_POWER_PLANT_ITEM_4", c_uint8, 1),
("HS_POWER_PLANT_ITEM_5", c_uint8, 1),
("HS_MOLTRES", c_uint8, 1),
("HS_VICTORY_ROAD_2F_ITEM_1", c_uint8, 1),
("HS_VICTORY_ROAD_2F_ITEM_2", c_uint8, 1),
("HS_VICTORY_ROAD_2F_ITEM_3", c_uint8, 1),
("HS_VICTORY_ROAD_2F_ITEM_4", c_uint8, 1),
("HS_VICTORY_ROAD_2F_BOULDER", c_uint8, 1),
("HS_BILL_POKEMON", c_uint8, 1),
("HS_BILL_1", c_uint8, 1),
("HS_BILL_2", c_uint8, 1),
("HS_VIRIDIAN_FOREST_ITEM_1", c_uint8, 1),
("HS_VIRIDIAN_FOREST_ITEM_2", c_uint8, 1),
("HS_VIRIDIAN_FOREST_ITEM_3", c_uint8, 1),
("HS_MT_MOON_1F_ITEM_1", c_uint8, 1),
("HS_MT_MOON_1F_ITEM_2", c_uint8, 1),
("HS_MT_MOON_1F_ITEM_3", c_uint8, 1),
("HS_MT_MOON_1F_ITEM_4", c_uint8, 1),
("HS_MT_MOON_1F_ITEM_5", c_uint8, 1),
("HS_MT_MOON_1F_ITEM_6", c_uint8, 1),
("HS_MT_MOON_B2F_FOSSIL_1", c_uint8, 1),
("HS_MT_MOON_B2F_FOSSIL_2", c_uint8, 1),
("HS_MT_MOON_B2F_ITEM_1", c_uint8, 1),
("HS_MT_MOON_B2F_ITEM_2", c_uint8, 1),
("HS_SS_ANNE_2F_RIVAL", c_uint8, 1),
("HS_SS_ANNE_1F_ROOMS_ITEM", c_uint8, 1),
("HS_SS_ANNE_2F_ROOMS_ITEM_1", c_uint8, 1),
("HS_SS_ANNE_2F_ROOMS_ITEM_2", c_uint8, 1),
("HS_SS_ANNE_B1F_ROOMS_ITEM_1", c_uint8, 1),
("HS_SS_ANNE_B1F_ROOMS_ITEM_2", c_uint8, 1),
("HS_SS_ANNE_B1F_ROOMS_ITEM_3", c_uint8, 1),
("HS_VICTORY_ROAD_3F_ITEM_1", c_uint8, 1),
("HS_VICTORY_ROAD_3F_ITEM_2", c_uint8, 1),
("HS_VICTORY_ROAD_3F_BOULDER", c_uint8, 1),
("HS_ROCKET_HIDEOUT_B1F_ITEM_1", c_uint8, 1),
("HS_ROCKET_HIDEOUT_B1F_ITEM_2", c_uint8, 1),
("HS_ROCKET_HIDEOUT_B2F_ITEM_1", c_uint8, 1),
("HS_ROCKET_HIDEOUT_B2F_ITEM_2", c_uint8, 1),
("HS_ROCKET_HIDEOUT_B2F_ITEM_3", c_uint8, 1),
("HS_ROCKET_HIDEOUT_B2F_ITEM_4", c_uint8, 1),
("HS_ROCKET_HIDEOUT_B3F_ITEM_1", c_uint8, 1),
("HS_ROCKET_HIDEOUT_B3F_ITEM_2", c_uint8, 1),
("HS_ROCKET_HIDEOUT_B4F_GIOVANNI", c_uint8, 1),
("HS_ROCKET_HIDEOUT_B4F_ITEM_1", c_uint8, 1),
("HS_ROCKET_HIDEOUT_B4F_ITEM_2", c_uint8, 1),
("HS_ROCKET_HIDEOUT_B4F_ITEM_3", c_uint8, 1),
("HS_ROCKET_HIDEOUT_B4F_ITEM_4", c_uint8, 1),
("HS_ROCKET_HIDEOUT_B4F_ITEM_5", c_uint8, 1),
("HS_SILPH_CO_2F_1", c_uint8, 1),
("HS_SILPH_CO_2F_2", c_uint8, 1),
("HS_SILPH_CO_2F_3", c_uint8, 1),
("HS_SILPH_CO_2F_4", c_uint8, 1),
("HS_SILPH_CO_2F_5", c_uint8, 1),
("HS_SILPH_CO_3F_1", c_uint8, 1),
("HS_SILPH_CO_3F_2", c_uint8, 1),
("HS_SILPH_CO_3F_ITEM", c_uint8, 1),
("HS_SILPH_CO_4F_1", c_uint8, 1),
("HS_SILPH_CO_4F_2", c_uint8, 1),
("HS_SILPH_CO_4F_3", c_uint8, 1),
("HS_SILPH_CO_4F_ITEM_1", c_uint8, 1),
("HS_SILPH_CO_4F_ITEM_2", c_uint8, 1),
("HS_SILPH_CO_4F_ITEM_3", c_uint8, 1),
("HS_SILPH_CO_5F_1", c_uint8, 1),
("HS_SILPH_CO_5F_2", c_uint8, 1),
("HS_SILPH_CO_5F_3", c_uint8, 1),
("HS_SILPH_CO_5F_4", c_uint8, 1),
("HS_SILPH_CO_5F_ITEM_1", c_uint8, 1),
("HS_SILPH_CO_5F_ITEM_2", c_uint8, 1),
("HS_SILPH_CO_5F_ITEM_3", c_uint8, 1),
("HS_SILPH_CO_6F_1", c_uint8, 1),
("HS_SILPH_CO_6F_2", c_uint8, 1),
("HS_SILPH_CO_6F_3", c_uint8, 1),
("HS_SILPH_CO_6F_ITEM_1", c_uint8, 1),
("HS_SILPH_CO_6F_ITEM_2", c_uint8, 1),
("HS_SILPH_CO_7F_1", c_uint8, 1),
("HS_SILPH_CO_7F_2", c_uint8, 1),
("HS_SILPH_CO_7F_3", c_uint8, 1),
("HS_SILPH_CO_7F_4", c_uint8, 1),
("HS_SILPH_CO_7F_RIVAL", c_uint8, 1),
("HS_SILPH_CO_7F_ITEM_1", c_uint8, 1),
("HS_SILPH_CO_7F_ITEM_2", c_uint8, 1),
("HS_SILPH_CO_7F_8", c_uint8, 1),
("HS_SILPH_CO_8F_1", c_uint8, 1),
("HS_SILPH_CO_8F_2", c_uint8, 1),
("HS_SILPH_CO_8F_3", c_uint8, 1),
("HS_SILPH_CO_9F_1", c_uint8, 1),
("HS_SILPH_CO_9F_2", c_uint8, 1),
("HS_SILPH_CO_9F_3", c_uint8, 1),
("HS_SILPH_CO_10F_1", c_uint8, 1),
("HS_SILPH_CO_10F_2", c_uint8, 1),
("HS_SILPH_CO_10F_3", c_uint8, 1),
("HS_SILPH_CO_10F_ITEM_1", c_uint8, 1),
("HS_SILPH_CO_10F_ITEM_2", c_uint8, 1),
("HS_SILPH_CO_10F_ITEM_3", c_uint8, 1),
("HS_SILPH_CO_11F_1", c_uint8, 1),
("HS_SILPH_CO_11F_2", c_uint8, 1),
("HS_SILPH_CO_11F_3", c_uint8, 1),
("HS_UNUSED_MAP_F4_1", c_uint8, 1),
("HS_POKEMON_MANSION_2F_ITEM", c_uint8, 1),
("HS_POKEMON_MANSION_3F_ITEM_1", c_uint8, 1),
("HS_POKEMON_MANSION_3F_ITEM_2", c_uint8, 1),
("HS_POKEMON_MANSION_B1F_ITEM_1", c_uint8, 1),
("HS_POKEMON_MANSION_B1F_ITEM_2", c_uint8, 1),
("HS_POKEMON_MANSION_B1F_ITEM_3", c_uint8, 1),
("HS_POKEMON_MANSION_B1F_ITEM_4", c_uint8, 1),
("HS_POKEMON_MANSION_B1F_ITEM_5", c_uint8, 1),
("HS_SAFARI_ZONE_EAST_ITEM_1", c_uint8, 1),
("HS_SAFARI_ZONE_EAST_ITEM_2", c_uint8, 1),
("HS_SAFARI_ZONE_EAST_ITEM_3", c_uint8, 1),
("HS_SAFARI_ZONE_EAST_ITEM_4", c_uint8, 1),
("HS_SAFARI_ZONE_NORTH_ITEM_1", c_uint8, 1),
("HS_SAFARI_ZONE_NORTH_ITEM_2", c_uint8, 1),
("HS_SAFARI_ZONE_WEST_ITEM_1", c_uint8, 1),
("HS_SAFARI_ZONE_WEST_ITEM_2", c_uint8, 1),
("HS_SAFARI_ZONE_WEST_ITEM_3", c_uint8, 1),
("HS_SAFARI_ZONE_WEST_ITEM_4", c_uint8, 1),
("HS_SAFARI_ZONE_CENTER_ITEM", c_uint8, 1),
("HS_CERULEAN_CAVE_2F_ITEM_1", c_uint8, 1),
("HS_CERULEAN_CAVE_2F_ITEM_2", c_uint8, 1),
("HS_CERULEAN_CAVE_2F_ITEM_3", c_uint8, 1),
("HS_MEWTWO", c_uint8, 1),
("HS_CERULEAN_CAVE_B1F_ITEM_1", c_uint8, 1),
("HS_CERULEAN_CAVE_B1F_ITEM_2", c_uint8, 1),
("HS_VICTORY_ROAD_1F_ITEM_1", c_uint8, 1),
("HS_VICTORY_ROAD_1F_ITEM_2", c_uint8, 1),
("HS_CHAMPIONS_ROOM_OAK", c_uint8, 1),
("HS_SEAFOAM_ISLANDS_1F_BOULDER_1", c_uint8, 1),
("HS_SEAFOAM_ISLANDS_1F_BOULDER_2", c_uint8, 1),
("HS_SEAFOAM_ISLANDS_B1F_BOULDER_1", c_uint8, 1),
("HS_SEAFOAM_ISLANDS_B1F_BOULDER_2", c_uint8, 1),
("HS_SEAFOAM_ISLANDS_B2F_BOULDER_1", c_uint8, 1),
("HS_SEAFOAM_ISLANDS_B2F_BOULDER_2", c_uint8, 1),
("HS_SEAFOAM_ISLANDS_B3F_BOULDER_1", c_uint8, 1),
("HS_SEAFOAM_ISLANDS_B3F_BOULDER_2", c_uint8, 1),
("HS_SEAFOAM_ISLANDS_B3F_BOULDER_3", c_uint8, 1),
("HS_SEAFOAM_ISLANDS_B3F_BOULDER_4", c_uint8, 1),
("HS_SEAFOAM_ISLANDS_B4F_BOULDER_1", c_uint8, 1),
("HS_SEAFOAM_ISLANDS_B4F_BOULDER_2", c_uint8, 1),
("HS_ARTICUNO", c_uint8, 1),
]


class MissableFlags(Union):
# These missable flags is a 32 byte object
_fields_ = [("b", MissableFlagsBits), ("asbytes", c_uint8 * 32)]

def __init__(self, emu: PyBoy):
super().__init__()
self.asbytes = (c_uint8 * 32)(*emu.memory[0xD5A6 : 0xD5A6 + 32])

def get_missable(self, missable: str) -> bool:
return bool(getattr(self.b, missable))
14 changes: 13 additions & 1 deletion pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
USEFUL_ITEMS,
Items,
)
from pokemonred_puffer.data.missable_objects import MissableFlags
from pokemonred_puffer.data.strength_puzzles import STRENGTH_SOLUTIONS
from pokemonred_puffer.data.tilesets import Tilesets
from pokemonred_puffer.data.tm_hm import (
Expand Down Expand Up @@ -166,6 +167,8 @@ def __init__(self, env_config: pufferlib.namespace):
low=0, high=max(Items._value2member_map_.keys()), shape=(20,), dtype=np.uint8
),
"bag_quantity": spaces.Box(low=0, high=100, shape=(20,), dtype=np.uint8),
"rival_3": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8),
"game_corner_rocket": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8),
} | {
event: spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8)
for event in REQUIRED_EVENTS
Expand Down Expand Up @@ -290,6 +293,7 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] =
self.reset_mem()

self.events = EventFlags(self.pyboy)
self.missables = MissableFlags(self.pyboy)
self.update_pokedex()
self.update_tm_hm_moves_obtained()
self.taught_cut = self.check_if_party_has_hm(0xF)
Expand Down Expand Up @@ -515,6 +519,10 @@ def _get_obs(self):
"wJoyIgnore": np.array(self.read_m("wJoyIgnore"), dtype=np.uint8),
"bag_items": bag[::2].copy(),
"bag_quantity": bag[1::2].copy(),
"rival_3": np.array(self.read_m("wSSAnne2FCurScript") == 4, dtype=np.uint8),
"game_corner_rocket": np.array(
self.missables.get_missable("HS_GAME_CORNER_ROCKET"), dtype=np.uint8
),
}
| {event: np.array(self.events.get_event(event)) for event in REQUIRED_EVENTS}
)
Expand Down Expand Up @@ -554,6 +562,7 @@ def step(self, action):

self.run_action_on_emulator(action)
self.events = EventFlags(self.pyboy)
self.missables = MissableFlags(self.pyboy)
self.update_seen_coords()
self.update_health()
self.update_pokedex()
Expand Down Expand Up @@ -1127,7 +1136,10 @@ def agent_stats(self, action):
}
| {f"badge_{i+1}": bool(badges & (1 << i)) for i in range(8)},
"events": {event: self.events.get_event(event) for event in REQUIRED_EVENTS}
| {"rival3": int(self.read_m(0xD665) == 4)},
| {
"rival3": int(self.read_m(0xD665) == 4),
"game_corner_rocket": self.missables.get_missable("HS_GAME_CORNER_ROCKET"),
},
"required_items": {item.name: item.value in bag_item_ids for item in REQUIRED_ITEMS},
"useful_items": {item.name: item.value in bag_item_ids for item in USEFUL_ITEMS},
"reward": self.get_game_state_reward(),
Expand Down
2 changes: 2 additions & 0 deletions pokemonred_puffer/policies/multi_convolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ def encode_observations(self, observations):
blackout_map_id.squeeze(1),
observations["wJoyIgnore"].float(),
items.flatten(start_dim=1),
observations["rival_3"].float(),
observations["game_corner_rocket"].float(),
)
+ tuple(observations[event].float() for event in REQUIRED_EVENTS),
dim=-1,
Expand Down
2 changes: 2 additions & 0 deletions pokemonred_puffer/rewards/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,8 @@ def get_game_state_reward(self):
"seen_action_bag_menu": self.seen_action_bag_menu
* self.reward_config["seen_action_bag_menu"],
"rival3": self.reward_config["event"] * int(self.read_m("wSSAnne2FCurScript") == 4),
"game_corner_rocket": self.reward_config["event"]
* float(self.missables.get_missable("HS_GAME_CORNER_ROCKET")),
}
| {
event: self.reward_config["required_event"] * float(self.events.get_event(event))
Expand Down

0 comments on commit 467df24

Please sign in to comment.