From 97f65f26baae6f43ded230cd94cc9f73f88ad80b Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Thu, 25 Jul 2024 08:20:58 -0400 Subject: [PATCH] clean up event obs --- config.yaml | 6 ++-- pokemonred_puffer/environment.py | 21 ++++++------ .../policies/multi_convolutional.py | 32 ++++++++----------- 3 files changed, 26 insertions(+), 33 deletions(-) diff --git a/config.yaml b/config.yaml index 8b30488..605d275 100644 --- a/config.yaml +++ b/config.yaml @@ -19,7 +19,7 @@ debug: device: cpu compile: False compile_mode: default - num_envs: 16 + num_envs: 1 envs_per_worker: 1 num_workers: 1 env_batch_size: 32 @@ -37,8 +37,8 @@ debug: verbose: False env_pool: False load_optimizer_state: False - swarm_frequency: 1 - swarm_keep_pct: .1 + # swarm_frequency: 1 + # swarm_keep_pct: .1 env: headless: True diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 5f7d180..6a10e47 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -173,9 +173,6 @@ def __init__(self, env_config: pufferlib.namespace): low=0, high=max(Items._value2member_map_.keys()), shape=(20,), dtype=np.uint8 ), "bag_quantity": spaces.Box(low=0, high=100, shape=(20,), dtype=np.uint8), - "rival_3": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), - "game_corner_rocket": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), - "saffron_guard": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), # This could be a dict within a sequence, but we'll do it like this and concat later "species": spaces.Box(low=0, high=0xBE, shape=(6,), dtype=np.uint8), "hp": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint16), @@ -189,8 +186,9 @@ def __init__(self, env_config: pufferlib.namespace): "speed": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint16), "special": spaces.Box(low=0, high=714, shape=(6,), dtype=np.uint16), "moves": spaces.Box(low=0, high=0xA4, shape=(6, 4), dtype=np.uint8), + # Add 3 for rival_3, game corner rocket and saffron guard "required_events": spaces.Box( - low=0, high=1, shape=(len(REQUIRED_EVENTS),), dtype=np.uint8 + low=0, high=1, shape=(len(REQUIRED_EVENTS) + 3,), dtype=np.uint8 ), } @@ -559,13 +557,6 @@ def _get_obs(self): "wJoyIgnore": np.array(self.read_m("wJoyIgnore"), dtype=np.uint8), "bag_items": bag[::2].copy(), "bag_quantity": bag[1::2].copy(), - "rival_3": np.array(self.read_m("wSSAnne2FCurScript") == 4, dtype=np.uint8), - "game_corner_rocket": np.array( - self.missables.get_missable("HS_GAME_CORNER_ROCKET"), dtype=np.uint8 - ), - "saffron_guard": np.array( - self.wd728.get_bit("GAVE_SAFFRON_GUARD_DRINK"), dtype=np.uint8 - ), "species": np.array([self.party[i].Species for i in range(6)], dtype=np.uint8), "hp": np.array([self.party[i].HP for i in range(6)], dtype=np.uint16), "status": np.array([self.party[i].Status for i in range(6)], dtype=np.uint8), @@ -579,7 +570,13 @@ def _get_obs(self): "special": np.array([self.party[i].Special for i in range(6)], dtype=np.uint16), "moves": np.array([self.party[i].Moves for i in range(6)], dtype=np.uint8), "required_events": np.array( - [self.events.get_event(event) for event in REQUIRED_EVENTS], dtype=np.uint8 + [self.events.get_event(event) for event in REQUIRED_EVENTS] + + [ + self.read_m("wSSAnne2FCurScript") == 4, # rival 3 + self.missables.get_missable("HS_GAME_CORNER_ROCKET"), # game corner rocket + self.wd728.get_bit("GAVE_SAFFRON_GUARD_DRINK"), # saffron guard + ], + dtype=np.uint8, ), } diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py index f51eea9..34bf21e 100644 --- a/pokemonred_puffer/policies/multi_convolutional.py +++ b/pokemonred_puffer/policies/multi_convolutional.py @@ -4,7 +4,6 @@ import torch from torch import nn -from pokemonred_puffer.data.events import REQUIRED_EVENTS from pokemonred_puffer.data.items import Items from pokemonred_puffer.environment import PIXEL_VALUES @@ -110,9 +109,8 @@ def __init__( self.moves_embeddings = nn.Embedding(0xA4, int(0xA4**0.25) + 1, dtype=torch.float32) # event embeddings - self.event_embeddings = nn.Embedding( - len(REQUIRED_EVENTS), int(len(REQUIRED_EVENTS) ** 0.25) + 1, dtype=torch.float32 - ) + n_events = env.env.observation_space["required_events"].shape[0] + self.event_embeddings = nn.Embedding(n_events, int(n_events**0.25) + 1, dtype=torch.float32) def forward(self, observations): hidden, lookup = self.encode_observations(observations) @@ -157,13 +155,14 @@ def encode_observations(self, observations): .int(), ).reshape(restored_global_map_shape) # badges = self.badge_buffer <= observations["badges"] - map_id = self.map_embeddings(observations["map_id"].long()) - blackout_map_id = self.map_embeddings(observations["blackout_map_id"].long()) + map_id = self.map_embeddings(observations["map_id"].long()).squeeze(1) + blackout_map_id = self.map_embeddings(observations["blackout_map_id"].long()).squeeze(1) # The bag quantity can be a value between 1 and 99 # TODO: Should items be positionally encoded? I dont think it matters - items = self.item_embeddings(observations["bag_items"].squeeze(1).long()).float() * ( - observations["bag_quantity"].squeeze(1).float().unsqueeze(-1) / 100.0 - ) + items = ( + self.item_embeddings(observations["bag_items"].long()) + * (observations["bag_quantity"].float().unsqueeze(-1) / 100.0) + ).squeeze(1) # image_observation = torch.cat((screen, visited_mask, global_map), dim=-1) image_observation = torch.cat((screen, visited_mask), dim=-1) @@ -178,9 +177,9 @@ def encode_observations(self, observations): # party network species = self.species_embeddings(observations["species"].squeeze(1).long()).float() - status = one_hot(observations["status"].long(), 7).squeeze(1).float() - type1 = self.type_embeddings(observations["type1"].squeeze(1).long()).float() - type2 = self.type_embeddings(observations["type2"].squeeze(1).long()).float() + status = one_hot(observations["status"].long(), 7).float().squeeze(1) + type1 = self.type_embeddings(observations["type1"].long()).squeeze(1) + type2 = self.type_embeddings(observations["type2"].long()).squeeze(1) moves = ( self.moves_embeddings(observations["moves"].squeeze(1).long()) .float() @@ -205,9 +204,9 @@ def encode_observations(self, observations): ) party_latent = self.party_network(party_obs) - event_obs = (observations["required_events"].float() @ self.event_embeddings.weight) / len( - REQUIRED_EVENTS - ) + event_obs = ( + observations["required_events"].float() @ self.event_embeddings.weight + ) / self.event_embeddings.weight.shape[0] cat_obs = torch.cat( ( self.screen_network(image_observation.float() / 255.0).squeeze(1), @@ -224,9 +223,6 @@ def encode_observations(self, observations): blackout_map_id.squeeze(1), observations["wJoyIgnore"].float(), items.flatten(start_dim=1), - observations["rival_3"].float(), - observations["game_corner_rocket"].float(), - observations["saffron_guard"].float(), party_latent, event_obs, ),