From bf5f1a56388d78a0ca9bf9363df9a59ddc784a68 Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Sun, 28 Jul 2024 23:36:12 -0400 Subject: [PATCH] use int conversion for embedding keys --- .../policies/multi_convolutional.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py index f1f13ac..158406f 100644 --- a/pokemonred_puffer/policies/multi_convolutional.py +++ b/pokemonred_puffer/policies/multi_convolutional.py @@ -155,12 +155,12 @@ def encode_observations(self, observations): .int(), ).reshape(restored_global_map_shape) # badges = self.badge_buffer <= observations["badges"] - map_id = self.map_embeddings(observations["map_id"].long()).squeeze(1) - blackout_map_id = self.map_embeddings(observations["blackout_map_id"].long()).squeeze(1) + map_id = self.map_embeddings(observations["map_id"].int()).squeeze(1) + blackout_map_id = self.map_embeddings(observations["blackout_map_id"].int()).squeeze(1) # The bag quantity can be a value between 1 and 99 # TODO: Should items be positionally encoded? I dont think it matters items = ( - self.item_embeddings(observations["bag_items"].long()) + self.item_embeddings(observations["bag_items"].int()) * (observations["bag_quantity"].float().unsqueeze(-1) / 100.0) ).squeeze(1) @@ -176,12 +176,12 @@ def encode_observations(self, observations): image_observation = image_observation[:, :, :: self.downsample, :: self.downsample] # party network - species = self.species_embeddings(observations["species"].squeeze(1).long()).float() - status = one_hot(observations["status"].long(), 7).float().squeeze(1) - type1 = self.type_embeddings(observations["type1"].long()).squeeze(1) - type2 = self.type_embeddings(observations["type2"].long()).squeeze(1) + species = self.species_embeddings(observations["species"].squeeze(1).int()).float() + status = one_hot(observations["status"].int(), 7).float().squeeze(1) + type1 = self.type_embeddings(observations["type1"].int()).squeeze(1) + type2 = self.type_embeddings(observations["type2"].int()).squeeze(1) moves = ( - self.moves_embeddings(observations["moves"].squeeze(1).long()) + self.moves_embeddings(observations["moves"].squeeze(1).int()) .float() .reshape((-1, 6, 4 * self.moves_embeddings.embedding_dim)) ) @@ -210,14 +210,14 @@ def encode_observations(self, observations): cat_obs = torch.cat( ( self.screen_network(image_observation.float() / 255.0).squeeze(1), - one_hot(observations["direction"].long(), 4).float().squeeze(1), - # one_hot(observations["reset_map_id"].long(), 0xF7).float().squeeze(1), - one_hot(observations["battle_type"].long(), 4).float().squeeze(1), + one_hot(observations["direction"].int(), 4).float().squeeze(1), + # one_hot(observations["reset_map_id"].int(), 0xF7).float().squeeze(1), + one_hot(observations["battle_type"].int(), 4).float().squeeze(1), # observations["cut_event"].float(), observations["cut_in_party"].float(), # observations["x"].float(), # observations["y"].float(), - # one_hot(observations["map_id"].long(), 0xF7).float().squeeze(1), + # one_hot(observations["map_id"].int(), 0xF7).float().squeeze(1), # badges.float().squeeze(1), map_id.squeeze(1), blackout_map_id.squeeze(1),