diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 2e3d72e..ac4646f 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -290,7 +290,7 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = self.seen_pokemon = np.zeros(152, dtype=np.uint8) self.caught_pokemon = np.zeros(152, dtype=np.uint8) self.moves_obtained = np.zeros(0xA5, dtype=np.uint8) - self.pokecenters = np.zeros(255, dtype=np.uint8) + self.pokecenters = np.zeros(252, dtype=np.uint8) # lazy random seed setting if not seed: seed = random.randint(0, 4096) @@ -877,8 +877,8 @@ def get_levels_reward(self): # Level reward party_levels = self.read_party() self.max_level_sum = max(self.max_level_sum, sum(party_levels)) - if self.max_level_sum < 15: + if self.max_level_sum < 30: level_reward = 1 * self.max_level_sum else: - level_reward = 15 + (self.max_level_sum - 15) / 4 + level_reward = 30 + (self.max_level_sum - 30) / 4 return level_reward diff --git a/pokemonred_puffer/rewards/baseline.py b/pokemonred_puffer/rewards/baseline.py index 4ecbc27..725d79b 100644 --- a/pokemonred_puffer/rewards/baseline.py +++ b/pokemonred_puffer/rewards/baseline.py @@ -245,3 +245,12 @@ def get_game_state_reward(self): * int(self.read_bit(0xD7F2, 7)), "rival3": self.reward_config["event"] * int(self.read_m(0xD665) == 4), } + + def get_levels_reward(self): + party_size = self.read_m("wPartyCount") + party_levels = [self.read_m(f"wPartyMon{i+1}Level") for i in range(party_size)] + self.max_level_sum = max(self.max_level_sum, sum(party_levels)) + if self.max_level_sum < 15: + return self.max_level_sum + else: + return 15 # + (self.max_level_sum - 15) / 4