From ba12867da5da01bb30e8bc4c81a0d5d6a0e207ce Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Sun, 11 Aug 2024 18:16:10 -0400 Subject: [PATCH] dry --- pokemonred_puffer/environment.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 555c25f..0ca28cb 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -682,12 +682,8 @@ def step(self, action): self.step_count += 1 reset = ( - self.step_count - >= min( - self.max_steps, - self.max_steps - * (len(self.required_events) + len(self.required_items) * self.max_steps_scaling), - ) # or + self.step_count >= self.get_max_steps() + # or # self.caught_pokemon[6] == 1 # squirtle ) @@ -1294,9 +1290,7 @@ def agent_stats(self, action): "pokecenter_heal": self.pokecenter_heal, "in_battle": self.read_m("wIsInBattle") > 0, "event": self.progress_reward["event"], - "max_steps": self.max_steps - * (len(self.required_events) + len(self.required_items)) - * self.max_steps_scaling, + "max_steps": self.get_max_steps(), } | { "exploration": { @@ -1359,6 +1353,13 @@ def add_video_frame(self): def get_game_coords(self): return (self.read_m(0xD362), self.read_m(0xD361), self.read_m(0xD35E)) + def get_max_steps(self): + return min( + self.max_steps, + self.max_steps + * (len(self.required_events) + len(self.required_items) * self.max_steps_scaling), + ) + def update_seen_coords(self): inc = 0.5 if (self.read_m("wd736") & 0b1000_0000) else self.exploration_inc