diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 78de5b8..fa95647 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -207,11 +207,11 @@ def __init__(self, env_config: pufferlib.namespace): "direction": spaces.Box(low=0, high=4, shape=(1,), dtype=np.uint8), # "reset_map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8), "battle_type": spaces.Box(low=0, high=4, shape=(1,), dtype=np.uint8), - # "cut_in_party": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), + "cut_in_party": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), # "x": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8), # "y": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8), # "map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8), - # "badges": spaces.Box(low=0, high=8, shape=(1,), dtype=np.uint8), + "badges": spaces.Box(low=0, high=0xFFFF, shape=(1,), dtype=np.uint8), } ) @@ -515,11 +515,11 @@ def _get_obs(self): ), # "reset_map_id": np.array(self.read_m("wLastBlackoutMap"), dtype=np.uint8), "battle_type": np.array(self.read_m("wIsInBattle") + 1, dtype=np.uint8), - # "cut_in_party": np.array(self.check_if_party_has_cut(), dtype=np.uint8), + "cut_in_party": np.array(self.check_if_party_has_cut(), dtype=np.uint8), # "x": np.array(player_x, dtype=np.uint8), # "y": np.array(player_y, dtype=np.uint8), # "map_id": np.array(map_n, dtype=np.uint8), - # "badges": np.array(self.get_badges(), dtype=np.uint8), + "badges": np.array(self.read_m("wObtainedBadges"), dtype=np.uint8), } def set_perfect_iv_dvs(self): @@ -647,9 +647,9 @@ def cut_if_next(self): self.pyboy.tick(self.action_freq, render=True) # scroll to pokemon # 1 is the item index for pokemon - # for _ in range(24): - while self.pyboy.memory[self.pyboy.symbol_lookup("wCurrentMenuItem")[1]] != 1: - # break + for _ in range(24): + if self.pyboy.memory[self.pyboy.symbol_lookup("wCurrentMenuItem")[1]] != 1: + break self.pyboy.send_input(WindowEvent.PRESS_ARROW_DOWN) self.pyboy.send_input(WindowEvent.RELEASE_ARROW_DOWN, delay=8) self.pyboy.tick(self.action_freq, render=True) diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py index ced205e..0d1d947 100644 --- a/pokemonred_puffer/policies/multi_convolutional.py +++ b/pokemonred_puffer/policies/multi_convolutional.py @@ -69,6 +69,7 @@ def __init__( self.register_buffer( "unpack_shift", torch.tensor([6, 4, 2, 0], dtype=torch.uint8), persistent=False ) + self.register_buffer("binary_mask", torch.tensor([0] + [2**i for i in range(7)])) def encode_observations(self, observations): observations = unpack_batched_obs(observations, self.unflatten_context) @@ -98,6 +99,8 @@ def encode_observations(self, observations): .flatten() .int(), ).reshape(restored_shape) + # > 0 doesn't risk a type conversion + badges = (observations["badges"] >> self.binary_mask) > 0 image_observation = torch.cat((screen, visited_mask, global_map), dim=-1) if self.channels_last: @@ -112,11 +115,11 @@ def encode_observations(self, observations): one_hot(observations["direction"].long(), 4).float().squeeze(1), # one_hot(observations["reset_map_id"].long(), 0xF7).float().squeeze(1), one_hot(observations["battle_type"].long(), 4).float().squeeze(1), - # observations["cut_in_party"].float(), + observations["cut_in_party"].float(), # observations["x"].float(), # observations["y"].float(), # one_hot(observations["map_id"].long(), 0xF7).float().squeeze(1), - # one_hot(observations["badges"].long(), 8).float().squeeze(1), + badges.float().squeeze(1), ), dim=-1, )