diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 72fc2d4..31138a0 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -239,10 +239,11 @@ def __init__(self, env_config: pufferlib.namespace): # "reset_map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8), "battle_type": spaces.Box(low=0, high=4, shape=(1,), dtype=np.uint8), "cut_in_party": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), - # "x": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8), + # "x": spaces.Box(low=0, high=255, shape=(1,), dtype=np.u`int8), # "y": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8), # "map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8), - "badges": spaces.Box(low=0, high=0xFFFF, shape=(1,), dtype=np.uint8), + # "badges": spaces.Box(low=0, high=np.iinfo(np.uint16).max, shape=(1,), dtype=np.uint16), + "badges": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8), } ) @@ -612,8 +613,8 @@ def run_action_on_emulator(self, action): self.action_hist[action] += 1 # press button then release after some steps # TODO: Add video saving logic - self.pyboy.send_input(VALID_ACTIONS[action]) - self.pyboy.send_input(VALID_RELEASE_ACTIONS[action], delay=8) + # self.pyboy.send_input(VALID_ACTIONS[action]) + # self.pyboy.send_input(VALID_RELEASE_ACTIONS[action], delay=8) self.pyboy.tick(self.action_freq, render=True) if self.read_bit(0xD803, 0): @@ -682,7 +683,7 @@ def cut_if_next(self): # scroll to pokemon # 1 is the item index for pokemon for _ in range(24): - if self.pyboy.memory[self.pyboy.symbol_lookup("wCurrentMenuItem")[1]] != 1: + if self.pyboy.memory[self.pyboy.symbol_lookup("wCurrentMenuItem")[1]] == 1: break self.pyboy.send_input(WindowEvent.PRESS_ARROW_DOWN) self.pyboy.send_input(WindowEvent.RELEASE_ARROW_DOWN, delay=8) @@ -913,7 +914,7 @@ def read_event_bits(self): return self.pyboy.memory[addr : addr + EVENTS_FLAGS_LENGTH] def get_badges(self): - return self.read_m("wObtainedBadges").bit_count() + return self.read_short("wObtainedBadges").bit_count() def read_party(self): _, addr = self.pyboy.symbol_lookup("wPartySpecies") diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py index 0d1d947..81f8a6e 100644 --- a/pokemonred_puffer/policies/multi_convolutional.py +++ b/pokemonred_puffer/policies/multi_convolutional.py @@ -69,7 +69,7 @@ def __init__( self.register_buffer( "unpack_shift", torch.tensor([6, 4, 2, 0], dtype=torch.uint8), persistent=False ) - self.register_buffer("binary_mask", torch.tensor([0] + [2**i for i in range(7)])) + self.register_buffer("binary_mask", torch.tensor([2**i for i in range(8)])) def encode_observations(self, observations): observations = unpack_batched_obs(observations, self.unflatten_context) @@ -100,7 +100,7 @@ def encode_observations(self, observations): .int(), ).reshape(restored_shape) # > 0 doesn't risk a type conversion - badges = (observations["badges"] >> self.binary_mask) > 0 + badges = (observations["badges"] & self.binary_mask) > 0 image_observation = torch.cat((screen, visited_mask, global_map), dim=-1) if self.channels_last: