From 5801c11d47edd9b3ea0af17c25df2561f5080ae4 Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Sat, 7 Dec 2024 09:57:03 -0500 Subject: [PATCH] maybe we cant do successive cat --- .../policies/multi_convolutional.py | 22 +++++-------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py index c90a713..f57842c 100644 --- a/pokemonred_puffer/policies/multi_convolutional.py +++ b/pokemonred_puffer/policies/multi_convolutional.py @@ -253,25 +253,15 @@ def encode_observations(self, observations): observations["game_corner_rocket"].float(), observations["saffron_guard"].float(), observations["lapras"].float(), + ) + + (() if self.skip_safari_zone else (observations["safari_steps"].float() / 502.0,)) + + ( + () + if self.use_global_map + else (self.global_map_network(global_map.float() / 255.0).squeeze(1),) ), dim=-1, ) - if not self.skip_safari_zone: - cat_obs = torch.cat( - ( - cat_obs, - observations["safari_steps"].float() / 502.0, - ), - dim=-1, - ) - if self.use_global_map: - cat_obs = torch.cat( - ( - cat_obs, - self.global_map_network(global_map.float() / 255.0).squeeze(1), - ), - dim=-1, - ) return self.encode_linear(cat_obs), None def decode_actions(self, flat_hidden, lookup, concat=None):