From 06fdffa60c89d7afd4b6440f2235a43c3f7d0cf9 Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Sat, 29 Jun 2024 10:29:51 -0400 Subject: [PATCH] Remember to cast all network inputs to float --- pokemonred_puffer/policies/multi_convolutional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py index ae32a50..c38ac63 100644 --- a/pokemonred_puffer/policies/multi_convolutional.py +++ b/pokemonred_puffer/policies/multi_convolutional.py @@ -173,7 +173,7 @@ def encode_observations(self, observations): # party network species = self.species_embeddings(observations["species"].squeeze(1).long()).float() - status = one_hot(observations["status"].long(), 7).squeeze(1) + status = one_hot(observations["status"].long(), 7).squeeze(1).float() type1 = self.type_embeddings(observations["type1"].squeeze(1).long()).float() type2 = self.type_embeddings(observations["type2"].squeeze(1).long()).float() moves = (