diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 02738cb..84cfbb3 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -170,6 +170,19 @@ def __init__(self, env_config: pufferlib.namespace): "rival_3": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), "game_corner_rocket": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), "saffron_guard": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), + # This could be a dict within a sequence, but we'll do it like this and concat later + "species": spaces.Box(low=0, high=0xBE, shape=(6,), dtype=np.uint8), + "hp": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint16), + "status": spaces.Box(low=0, high=7, shape=(6,), dtype=np.uint8), + "type1": spaces.Box(low=0, high=0x1A, shape=(6,), dtype=np.uint8), + "type2": spaces.Box(low=0, high=0x1A, shape=(6,), dtype=np.uint8), + "level": spaces.Box(low=0, high=100, shape=(6,), dtype=np.uint8), + "maxHP": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint16), + "attack": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint16), + "defense": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint16), + "speed": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint16), + "special": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint16), + "moves": spaces.Box(low=0, high=0xA4, shape=(6, 4), dtype=np.uint8), } | { event: spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8) for event in REQUIRED_EVENTS @@ -284,6 +297,7 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = self.events = EventFlags(self.pyboy) self.missables = MissableFlags(self.pyboy) self.wd728 = Wd728Flags(self.pyboy) + self.party = PartyMons(self.pyboy) self.update_pokedex() self.update_tm_hm_moves_obtained() self.taught_cut = self.check_if_party_has_hm(0xF) @@ -522,6 +536,18 @@ def _get_obs(self): "saffron_guard": np.array( self.wd728.get_bit("GAVE_SAFFRON_GUARD_DRINK"), dtype=np.uint8 ), + "species": np.array([self.party[i].Species for i in range(6)], dtype=np.uint8), + "hp": np.array([self.party[i].HP for i in range(6)], dtype=np.uint16), + "status": np.array([self.party[i].Status for i in range(6)], dtype=np.uint8), + "type1": np.array([self.party[i].Type1 for i in range(6)], dtype=np.uint8), + "type2": np.array([self.party[i].Type2 for i in range(6)], dtype=np.uint8), + "level": np.array([self.party[i].Level for i in range(6)], dtype=np.uint8), + "maxHP": np.array([self.party[i].MaxHP for i in range(6)], dtype=np.uint16), + "attack": np.array([self.party[i].Attack for i in range(6)], dtype=np.uint16), + "defense": np.array([self.party[i].Defense for i in range(6)], dtype=np.uint16), + "speed": np.array([self.party[i].Speed for i in range(6)], dtype=np.uint16), + "special": np.array([self.party[i].Special for i in range(6)], dtype=np.uint16), + "moves": np.array([self.party[i].Moves for i in range(6)], dtype=np.uint8), } | {event: np.array(self.events.get_event(event)) for event in REQUIRED_EVENTS} ) @@ -565,6 +591,7 @@ def step(self, action): self.run_action_on_emulator(action) self.events = EventFlags(self.pyboy) self.missables = MissableFlags(self.pyboy) + self.wd728 = Wd728Flags(self.pyboy) self.party = PartyMons(self.pyboy) self.update_seen_coords() self.update_health() diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py index 80d895e..ae32a50 100644 --- a/pokemonred_puffer/policies/multi_convolutional.py +++ b/pokemonred_puffer/policies/multi_convolutional.py @@ -96,13 +96,19 @@ def __init__( # pokemon has 0xF7 map ids # Lets start with 4 dims for now. Could try 8 - self.map_embeddings = torch.nn.Embedding(0xF7, 4, dtype=torch.float32) + self.map_embeddings = nn.Embedding(0xF7, 4, dtype=torch.float32) # N.B. This is an overestimate item_count = max(Items._value2member_map_.keys()) - self.item_embeddings = torch.nn.Embedding( + self.item_embeddings = nn.Embedding( item_count, int(item_count**0.25 + 1), dtype=torch.float32 ) + # Party layers + self.party_network = nn.Sequential(nn.LazyLinear(6), nn.ReLU(), nn.Flatten()) + self.species_embeddings = nn.Embedding(0xBE, int(0xBE**0.25) + 1, dtype=torch.float32) + self.type_embeddings = nn.Embedding(0x1A, int(0x1A**0.25) + 1, dtype=torch.float32) + self.moves_embeddings = nn.Embedding(0xA4, int(0xA4**0.25) + 1, dtype=torch.float32) + def forward(self, observations): hidden, lookup = self.encode_observations(observations) actions, value = self.decode_actions(hidden, lookup) @@ -165,6 +171,35 @@ def encode_observations(self, observations): if self.downsample > 1: image_observation = image_observation[:, :, :: self.downsample, :: self.downsample] + # party network + species = self.species_embeddings(observations["species"].squeeze(1).long()).float() + status = one_hot(observations["status"].long(), 7).squeeze(1) + type1 = self.type_embeddings(observations["type1"].squeeze(1).long()).float() + type2 = self.type_embeddings(observations["type2"].squeeze(1).long()).float() + moves = ( + self.moves_embeddings(observations["moves"].squeeze(1).long()) + .float() + .reshape((-1, 6, 4 * self.moves_embeddings.embedding_dim)) + ) + party_obs = torch.cat( + ( + species, + observations["hp"].float().unsqueeze(-1) / 714.0, + status, + type1, + type2, + observations["level"].float().unsqueeze(-1) / 100.0, + observations["maxHP"].float().unsqueeze(-1) / 714.0, + observations["attack"].float().unsqueeze(-1) / 714.0, + observations["defense"].float().unsqueeze(-1) / 714.0, + observations["speed"].float().unsqueeze(-1) / 714.0, + observations["special"].float().unsqueeze(-1) / 714.0, + moves, + ), + dim=-1, + ) + party_latent = self.party_network(party_obs) + cat_obs = torch.cat( ( self.screen_network(image_observation.float() / 255.0).squeeze(1), @@ -184,6 +219,7 @@ def encode_observations(self, observations): observations["rival_3"].float(), observations["game_corner_rocket"].float(), observations["saffron_guard"].float(), + party_latent, ) + tuple(observations[event].float() for event in REQUIRED_EVENTS), dim=-1,