Skip to content

Commit

Permalink
party model
Browse files (browse the repository at this point in the history)
  • Loading branch information
thatguy11325 committed Jun 29, 2024
1 parent 2e5c2ac commit 4e9ba5f
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 2 deletions.
27 changes: 27 additions & 0 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,19 @@ def __init__(self, env_config: pufferlib.namespace):
"rival_3": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8),
"game_corner_rocket": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8),
"saffron_guard": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8),
# This could be a dict within a sequence, but we'll do it like this and concat later
"species": spaces.Box(low=0, high=0xBE, shape=(6,), dtype=np.uint8),
"hp": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint16),
"status": spaces.Box(low=0, high=7, shape=(6,), dtype=np.uint8),
"type1": spaces.Box(low=0, high=0x1A, shape=(6,), dtype=np.uint8),
"type2": spaces.Box(low=0, high=0x1A, shape=(6,), dtype=np.uint8),
"level": spaces.Box(low=0, high=100, shape=(6,), dtype=np.uint8),
"maxHP": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint16),
"attack": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint16),
"defense": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint16),
"speed": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint16),
"special": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint16),
"moves": spaces.Box(low=0, high=0xA4, shape=(6, 4), dtype=np.uint8),
} | {
event: spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8)
for event in REQUIRED_EVENTS
Expand Down Expand Up @@ -284,6 +297,7 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] =
self.events = EventFlags(self.pyboy)
self.missables = MissableFlags(self.pyboy)
self.wd728 = Wd728Flags(self.pyboy)
self.party = PartyMons(self.pyboy)
self.update_pokedex()
self.update_tm_hm_moves_obtained()
self.taught_cut = self.check_if_party_has_hm(0xF)
Expand Down Expand Up @@ -522,6 +536,18 @@ def _get_obs(self):
"saffron_guard": np.array(
self.wd728.get_bit("GAVE_SAFFRON_GUARD_DRINK"), dtype=np.uint8
),
"species": np.array([self.party[i].Species for i in range(6)], dtype=np.uint8),
"hp": np.array([self.party[i].HP for i in range(6)], dtype=np.uint16),
"status": np.array([self.party[i].Status for i in range(6)], dtype=np.uint8),
"type1": np.array([self.party[i].Type1 for i in range(6)], dtype=np.uint8),
"type2": np.array([self.party[i].Type2 for i in range(6)], dtype=np.uint8),
"level": np.array([self.party[i].Level for i in range(6)], dtype=np.uint8),
"maxHP": np.array([self.party[i].MaxHP for i in range(6)], dtype=np.uint16),
"attack": np.array([self.party[i].Attack for i in range(6)], dtype=np.uint16),
"defense": np.array([self.party[i].Defense for i in range(6)], dtype=np.uint16),
"speed": np.array([self.party[i].Speed for i in range(6)], dtype=np.uint16),
"special": np.array([self.party[i].Special for i in range(6)], dtype=np.uint16),
"moves": np.array([self.party[i].Moves for i in range(6)], dtype=np.uint8),
}
| {event: np.array(self.events.get_event(event)) for event in REQUIRED_EVENTS}
)
Expand Down Expand Up @@ -565,6 +591,7 @@ def step(self, action):
self.run_action_on_emulator(action)
self.events = EventFlags(self.pyboy)
self.missables = MissableFlags(self.pyboy)
self.wd728 = Wd728Flags(self.pyboy)
self.party = PartyMons(self.pyboy)
self.update_seen_coords()
self.update_health()
Expand Down
40 changes: 38 additions & 2 deletions pokemonred_puffer/policies/multi_convolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,19 @@ def __init__(

# pokemon has 0xF7 map ids
# Let's start with 4 dims for now. Could try 8
self.map_embeddings = torch.nn.Embedding(0xF7, 4, dtype=torch.float32)
self.map_embeddings = nn.Embedding(0xF7, 4, dtype=torch.float32)
# N.B. This is an overestimate
item_count = max(Items._value2member_map_.keys())
self.item_embeddings = torch.nn.Embedding(
self.item_embeddings = nn.Embedding(
item_count, int(item_count**0.25 + 1), dtype=torch.float32
)

# Party layers
self.party_network = nn.Sequential(nn.LazyLinear(6), nn.ReLU(), nn.Flatten())
self.species_embeddings = nn.Embedding(0xBE, int(0xBE**0.25) + 1, dtype=torch.float32)
self.type_embeddings = nn.Embedding(0x1A, int(0x1A**0.25) + 1, dtype=torch.float32)
self.moves_embeddings = nn.Embedding(0xA4, int(0xA4**0.25) + 1, dtype=torch.float32)

def forward(self, observations):
hidden, lookup = self.encode_observations(observations)
actions, value = self.decode_actions(hidden, lookup)
Expand Down Expand Up @@ -165,6 +171,35 @@ def encode_observations(self, observations):
if self.downsample > 1:
image_observation = image_observation[:, :, :: self.downsample, :: self.downsample]

# party network
species = self.species_embeddings(observations["species"].squeeze(1).long()).float()
status = one_hot(observations["status"].long(), 7).squeeze(1)
type1 = self.type_embeddings(observations["type1"].squeeze(1).long()).float()
type2 = self.type_embeddings(observations["type2"].squeeze(1).long()).float()
moves = (
self.moves_embeddings(observations["moves"].squeeze(1).long())
.float()
.reshape((-1, 6, 4 * self.moves_embeddings.embedding_dim))
)
party_obs = torch.cat(
(
species,
observations["hp"].float().unsqueeze(-1) / 714.0,
status,
type1,
type2,
observations["level"].float().unsqueeze(-1) / 100.0,
observations["maxHP"].float().unsqueeze(-1) / 714.0,
observations["attack"].float().unsqueeze(-1) / 714.0,
observations["defense"].float().unsqueeze(-1) / 714.0,
observations["speed"].float().unsqueeze(-1) / 714.0,
observations["special"].float().unsqueeze(-1) / 714.0,
moves,
),
dim=-1,
)
party_latent = self.party_network(party_obs)

cat_obs = torch.cat(
(
self.screen_network(image_observation.float() / 255.0).squeeze(1),
Expand All @@ -184,6 +219,7 @@ def encode_observations(self, observations):
observations["rival_3"].float(),
observations["game_corner_rocket"].float(),
observations["saffron_guard"].float(),
party_latent,
)
+ tuple(observations[event].float() for event in REQUIRED_EVENTS),
dim=-1,
Expand Down

0 comments on commit 4e9ba5f

Please sign in to comment.