diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py index f57842c..b60d799 100644 --- a/pokemonred_puffer/policies/multi_convolutional.py +++ b/pokemonred_puffer/policies/multi_convolutional.py @@ -256,9 +256,9 @@ def encode_observations(self, observations): ) + (() if self.skip_safari_zone else (observations["safari_steps"].float() / 502.0,)) + ( - () + (self.global_map_network(global_map.float() / 255.0).squeeze(1),) if self.use_global_map - else (self.global_map_network(global_map.float() / 255.0).squeeze(1),) + else () ), dim=-1, )