diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index bb5cf65..deaef2c 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -477,7 +477,10 @@ def _get_obs(self): # player_x, player_y, map_n = self.get_game_coords() _, wBagItems = self.pyboy.symbol_lookup("wBagItems") bag = self.pyboy.memory[wBagItems : wBagItems + 40] - end_of_bag = list(bag[::2]).index(0xFF) + try: + end_of_bag = 2 * list(bag[::2]).index(0xFF) + except ValueError: + end_of_bag = len(bag) bag = np.array(bag, dtype=np.uint8) bag[end_of_bag:] = 0 @@ -1055,7 +1058,8 @@ def disable_wild_encounter_hook(self, *args, **kwargs): def agent_stats(self, action): levels = [self.read_m(f"wPartyMon{i+1}Level") for i in range(self.read_m("wPartyCount"))] - badges = self.read_short("wObtainedBadges") + badges = self.read_m("wObtainedBadges") + breakpoint() return { "stats": { "step": self.step_count + self.reset_count * self.max_steps, @@ -1103,7 +1107,7 @@ def agent_stats(self, action): "rival3": int(self.read_m(0xD665) == 4), "rocket_hideout_found": int(self.read_bit(0xD77E, 1)), } - | {f"badge_{i+1}": badges & (1 << i) for i in range(8)}, + | {f"badge_{i+1}": bool(badges & (1 << i)) for i in range(8)}, "reward": self.get_game_state_reward(), "reward/reward_sum": sum(self.get_game_state_reward().values()), "pokemon_exploration_map": self.explore_map,