Skip to content

Commit

Permalink
add back in badges obs, fix cut_if_next
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Jun 2, 2024
1 parent 62f70df commit e8436fc
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
14 changes: 7 additions & 7 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,11 +207,11 @@ def __init__(self, env_config: pufferlib.namespace):
"direction": spaces.Box(low=0, high=4, shape=(1,), dtype=np.uint8),
# "reset_map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8),
"battle_type": spaces.Box(low=0, high=4, shape=(1,), dtype=np.uint8),
# "cut_in_party": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8),
"cut_in_party": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8),
# "x": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8),
# "y": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8),
# "map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8),
# "badges": spaces.Box(low=0, high=8, shape=(1,), dtype=np.uint8),
"badges": spaces.Box(low=0, high=0xFFFF, shape=(1,), dtype=np.uint8),
}
)

Expand Down Expand Up @@ -515,11 +515,11 @@ def _get_obs(self):
),
# "reset_map_id": np.array(self.read_m("wLastBlackoutMap"), dtype=np.uint8),
"battle_type": np.array(self.read_m("wIsInBattle") + 1, dtype=np.uint8),
# "cut_in_party": np.array(self.check_if_party_has_cut(), dtype=np.uint8),
"cut_in_party": np.array(self.check_if_party_has_cut(), dtype=np.uint8),
# "x": np.array(player_x, dtype=np.uint8),
# "y": np.array(player_y, dtype=np.uint8),
# "map_id": np.array(map_n, dtype=np.uint8),
# "badges": np.array(self.get_badges(), dtype=np.uint8),
"badges": np.array(self.read_m("wObtainedBadges"), dtype=np.uint8),
}

def set_perfect_iv_dvs(self):
Expand Down Expand Up @@ -647,9 +647,9 @@ def cut_if_next(self):
self.pyboy.tick(self.action_freq, render=True)
# scroll to pokemon
# 1 is the item index for pokemon
# for _ in range(24):
while self.pyboy.memory[self.pyboy.symbol_lookup("wCurrentMenuItem")[1]] != 1:
# break
for _ in range(24):
if self.pyboy.memory[self.pyboy.symbol_lookup("wCurrentMenuItem")[1]] != 1:
break
self.pyboy.send_input(WindowEvent.PRESS_ARROW_DOWN)
self.pyboy.send_input(WindowEvent.RELEASE_ARROW_DOWN, delay=8)
self.pyboy.tick(self.action_freq, render=True)
Expand Down
7 changes: 5 additions & 2 deletions pokemonred_puffer/policies/multi_convolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(
self.register_buffer(
"unpack_shift", torch.tensor([6, 4, 2, 0], dtype=torch.uint8), persistent=False
)
self.register_buffer("binary_mask", torch.tensor([0] + [2**i for i in range(7)]))

def encode_observations(self, observations):
observations = unpack_batched_obs(observations, self.unflatten_context)
Expand Down Expand Up @@ -98,6 +99,8 @@ def encode_observations(self, observations):
.flatten()
.int(),
).reshape(restored_shape)
# > 0 doesn't risk a type conversion
badges = (observations["badges"] >> self.binary_mask) > 0

image_observation = torch.cat((screen, visited_mask, global_map), dim=-1)
if self.channels_last:
Expand All @@ -112,11 +115,11 @@ def encode_observations(self, observations):
one_hot(observations["direction"].long(), 4).float().squeeze(1),
# one_hot(observations["reset_map_id"].long(), 0xF7).float().squeeze(1),
one_hot(observations["battle_type"].long(), 4).float().squeeze(1),
# observations["cut_in_party"].float(),
observations["cut_in_party"].float(),
# observations["x"].float(),
# observations["y"].float(),
# one_hot(observations["map_id"].long(), 0xF7).float().squeeze(1),
# one_hot(observations["badges"].long(), 8).float().squeeze(1),
badges.float().squeeze(1),
),
dim=-1,
)
Expand Down

0 comments on commit e8436fc

Please sign in to comment.