Skip to content

Commit

Permalink
Fix cut script again
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Jun 3, 2024
1 parent d8568b1 commit c810a59
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
13 changes: 7 additions & 6 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,11 @@ def __init__(self, env_config: pufferlib.namespace):
# "reset_map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8),
"battle_type": spaces.Box(low=0, high=4, shape=(1,), dtype=np.uint8),
"cut_in_party": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8),
# "x": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8),
# "x": spaces.Box(low=0, high=255, shape=(1,), dtype=np.u`int8),
# "y": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8),
# "map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8),
"badges": spaces.Box(low=0, high=0xFFFF, shape=(1,), dtype=np.uint8),
# "badges": spaces.Box(low=0, high=np.iinfo(np.uint16).max, shape=(1,), dtype=np.uint16),
"badges": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8),
}
)

Expand Down Expand Up @@ -612,8 +613,8 @@ def run_action_on_emulator(self, action):
self.action_hist[action] += 1
# press button then release after some steps
# TODO: Add video saving logic
self.pyboy.send_input(VALID_ACTIONS[action])
self.pyboy.send_input(VALID_RELEASE_ACTIONS[action], delay=8)
# self.pyboy.send_input(VALID_ACTIONS[action])
# self.pyboy.send_input(VALID_RELEASE_ACTIONS[action], delay=8)
self.pyboy.tick(self.action_freq, render=True)

if self.read_bit(0xD803, 0):
Expand Down Expand Up @@ -682,7 +683,7 @@ def cut_if_next(self):
# scroll to pokemon
# 1 is the item index for pokemon
for _ in range(24):
if self.pyboy.memory[self.pyboy.symbol_lookup("wCurrentMenuItem")[1]] != 1:
if self.pyboy.memory[self.pyboy.symbol_lookup("wCurrentMenuItem")[1]] == 1:
break
self.pyboy.send_input(WindowEvent.PRESS_ARROW_DOWN)
self.pyboy.send_input(WindowEvent.RELEASE_ARROW_DOWN, delay=8)
Expand Down Expand Up @@ -913,7 +914,7 @@ def read_event_bits(self):
return self.pyboy.memory[addr : addr + EVENTS_FLAGS_LENGTH]

def get_badges(self):
return self.read_m("wObtainedBadges").bit_count()
return self.read_short("wObtainedBadges").bit_count()

def read_party(self):
_, addr = self.pyboy.symbol_lookup("wPartySpecies")
Expand Down
4 changes: 2 additions & 2 deletions pokemonred_puffer/policies/multi_convolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __init__(
self.register_buffer(
"unpack_shift", torch.tensor([6, 4, 2, 0], dtype=torch.uint8), persistent=False
)
self.register_buffer("binary_mask", torch.tensor([0] + [2**i for i in range(7)]))
self.register_buffer("binary_mask", torch.tensor([2**i for i in range(8)]))

def encode_observations(self, observations):
observations = unpack_batched_obs(observations, self.unflatten_context)
Expand Down Expand Up @@ -100,7 +100,7 @@ def encode_observations(self, observations):
.int(),
).reshape(restored_shape)
# > 0 doesn't risk a type conversion
badges = (observations["badges"] >> self.binary_mask) > 0
badges = (observations["badges"] & self.binary_mask) > 0

image_observation = torch.cat((screen, visited_mask, global_map), dim=-1)
if self.channels_last:
Expand Down

0 comments on commit c810a59

Please sign in to comment.