diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py index ae32a50..c38ac63 100644 --- a/pokemonred_puffer/policies/multi_convolutional.py +++ b/pokemonred_puffer/policies/multi_convolutional.py @@ -173,7 +173,7 @@ def encode_observations(self, observations): # party network species = self.species_embeddings(observations["species"].squeeze(1).long()).float() - status = one_hot(observations["status"].long(), 7).squeeze(1) + status = one_hot(observations["status"].long(), 7).squeeze(1).float() type1 = self.type_embeddings(observations["type1"].squeeze(1).long()).float() type2 = self.type_embeddings(observations["type2"].squeeze(1).long()).float() moves = (