# Patch summary (review notes; this preamble sits before the first "diff --git" header).
#
# Files touched: pokemonred_puffer/data/events.py, pokemonred_puffer/environment.py,
# pokemonred_puffer/policies/multi_convolutional.py.
#
# data/events.py:
#   - adds `import re`.
#   - adds a module-level `EVENTS` set built from the names in `EventFlagsBits._fields_`,
#     leaving out every name for which re.match("EVENT_[0-9,A-F]{3}$", name) succeeds.
# environment.py:
#   - imports `EVENTS` alongside the existing names from data.events.
#   - replaces the "required_events" observation entry with "events", a Box of shape
#     (len(EVENTS) + 3,) instead of (len(REQUIRED_EVENTS) + 3,).
#   - `_get_obs` fills "events" from `self.events.get_event(event) for event in EVENTS`,
#     followed by a list of extra entries that is only partly visible in this hunk.
# policies/multi_convolutional.py:
#   - reads `observation_space["events"]` / `observations["events"]` instead of the
#     former "required_events" key.
#
# NOTE(review): the character class `[0-9,A-F]` also accepts a literal comma;
#   `[0-9A-F]` was presumably intended -- confirm before changing.
# NOTE(review): `EVENTS` is a set and `_get_obs` iterates it to lay out the "events"
#   vector. Iteration order of a set of str is not guaranteed to be the same across
#   interpreter processes (str hash randomization), so the position of a given event
#   may differ between runs or workers -- confirm whether a fixed order (e.g. the
#   `_fields_` order, or sorted) is needed for checkpoint compatibility.
# NOTE(review): the patch text below appears to have had its line breaks collapsed
#   into spaces in this copy; it is reproduced unchanged and presumably will not
#   apply until the original line structure is restored.
diff --git a/pokemonred_puffer/data/events.py b/pokemonred_puffer/data/events.py index c4e1a3f..5f87f46 100644 --- a/pokemonred_puffer/data/events.py +++ b/pokemonred_puffer/data/events.py @@ -1,4 +1,5 @@ from ctypes import c_uint8, LittleEndianStructure, Union +import re from pyboy import PyBoy @@ -2585,6 +2586,11 @@ def get_event(self, event_name: str) -> bool: return bool(getattr(self.b, event_name)) +EVENTS = { + event for event, _ in EventFlagsBits._fields_ if not re.match("EVENT_[0-9,A-F]{3}$", event) +} + + REQUIRED_EVENTS = { "EVENT_FOLLOWED_OAK_INTO_LAB", "EVENT_PALLET_AFTER_GETTING_POKEBALLS", diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index f16a332..819be09 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -18,6 +18,7 @@ import pufferlib from pokemonred_puffer.data.events import ( EVENT_FLAGS_START, + EVENTS, EVENTS_FLAGS_LENGTH, MUSEUM_TICKET, REQUIRED_EVENTS, @@ -186,9 +187,7 @@ def __init__(self, env_config: pufferlib.namespace): "special": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint32), "moves": spaces.Box(low=0, high=0xA4, shape=(6, 4), dtype=np.uint8), # Add 3 for rival_3, game corner rocket and saffron guard - "required_events": spaces.Box( - low=0, high=1, shape=(len(REQUIRED_EVENTS) + 3,), dtype=np.uint8 - ), + "events": spaces.Box(low=0, high=1, shape=(len(EVENTS) + 3,), dtype=np.uint8), } if self.use_global_map: @@ -570,8 +569,8 @@ def _get_obs(self): "speed": np.array([self.party[i].Speed for i in range(6)], dtype=np.uint32), "special": np.array([self.party[i].Special for i in range(6)], dtype=np.uint32), "moves": np.array([self.party[i].Moves for i in range(6)], dtype=np.uint8), - "required_events": np.array( - [self.events.get_event(event) for event in REQUIRED_EVENTS] + "events": np.array( + [self.events.get_event(event) for event in EVENTS] + [ self.read_m("wSSAnne2FCurScript") == 4, # rival 3 self.missables.get_missable("HS_GAME_CORNER_ROCKET"), # game 
corner rocket diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py index 8ce9f2b..f1f13ac 100644 --- a/pokemonred_puffer/policies/multi_convolutional.py +++ b/pokemonred_puffer/policies/multi_convolutional.py @@ -109,7 +109,7 @@ def __init__( self.moves_embeddings = nn.Embedding(0xA4, int(0xA4**0.25) + 1, dtype=torch.float32) # event embeddings - n_events = env.env.observation_space["required_events"].shape[0] + n_events = env.env.observation_space["events"].shape[0] self.event_embeddings = nn.Embedding(n_events, int(n_events**0.25) + 1, dtype=torch.float32) def forward(self, observations): @@ -205,7 +205,7 @@ def encode_observations(self, observations): party_latent = self.party_network(party_obs) event_obs = ( - observations["required_events"].float() @ self.event_embeddings.weight + observations["events"].float() @ self.event_embeddings.weight ) / self.event_embeddings.weight.shape[0] cat_obs = torch.cat( (