Skip to content

Commit

Permalink
Remember to cast all network inputs to float
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Jun 29, 2024
1 parent 4e9ba5f commit 06fdffa
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion pokemonred_puffer/policies/multi_convolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def encode_observations(self, observations):

# party network
species = self.species_embeddings(observations["species"].squeeze(1).long()).float()
status = one_hot(observations["status"].long(), 7).squeeze(1)
status = one_hot(observations["status"].long(), 7).squeeze(1).float()
type1 = self.type_embeddings(observations["type1"].squeeze(1).long()).float()
type2 = self.type_embeddings(observations["type2"].squeeze(1).long()).float()
moves = (
Expand Down

0 comments on commit 06fdffa

Please sign in to comment.