diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py index 22f75b2..7636a51 100644 --- a/pokemonred_puffer/policies/multi_convolutional.py +++ b/pokemonred_puffer/policies/multi_convolutional.py @@ -232,7 +232,7 @@ def encode_observations(self, observations): ), dim=-1, ) - if self.skip_safari_zone: + if not self.skip_safari_zone: cat_obs = torch.cat( ( cat_obs,