From b6c0982e5c478de2d23d76375d74f3d1f9f34153 Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Tue, 16 Apr 2024 22:48:33 -0400 Subject: [PATCH] less np.sum abuse --- pokemonred_puffer/rewards/baseline.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pokemonred_puffer/rewards/baseline.py b/pokemonred_puffer/rewards/baseline.py index a46603a..4ecbc27 100644 --- a/pokemonred_puffer/rewards/baseline.py +++ b/pokemonred_puffer/rewards/baseline.py @@ -214,14 +214,14 @@ class RockTunnelReplicationEnv(TeachCutReplicationEnv): def get_game_state_reward(self): return { "level": self.reward_config["level"] * self.get_levels_reward(), - "exploration": self.reward_config["exploration"] * np.sum(self.seen_coords.values()), + "exploration": self.reward_config["exploration"] * sum(self.seen_coords.values()), "taught_cut": self.reward_config["taught_cut"] * int(self.taught_cut), "event": self.reward_config["event"] * self.update_max_event_rew(), "seen_pokemon": self.reward_config["seen_pokemon"] * np.sum(self.seen_pokemon), "caught_pokemon": self.reward_config["caught_pokemon"] * np.sum(self.caught_pokemon), "moves_obtained": self.reward_config["moves_obtained"] * np.sum(self.moves_obtained), - "cut_coords": self.reward_config["cut_coords"] * np.sum(self.cut_coords.values()), - "cut_tiles": self.reward_config["cut_tiles"] * np.sum(self.cut_tiles), + "cut_coords": self.reward_config["cut_coords"] * sum(self.cut_coords.values()), + "cut_tiles": self.reward_config["cut_tiles"] * sum(self.cut_tiles), "start_menu": ( self.reward_config["start_menu"] * self.seen_start_menu * int(self.taught_cut) ),