Skip to content

Commit

Permalink
fix badges again. Add rocket hideout specific event
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Jun 6, 2024
1 parent 4f1f8cb commit 2c3e8fc
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 41 deletions.
31 changes: 15 additions & 16 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ debug:
env:
headless: False
stream_wrapper: False
init_state: cut3
init_state: cut
max_steps: 1_000_000
train:
device: cpu
Expand Down Expand Up @@ -157,25 +157,24 @@ rewards:
explore_npcs: 0.02
explore_hidden_objs: 0.02

baseline.RockTunnelReplicationEnv:
baseline.CutWithObjectRewardsEnv:
reward:
level: 1.0
exploration: 0.02
taught_cut: 10.0
event: 3.0
seen_pokemon: 4.0
caught_pokemon: 4.0
event: 1.0
bill_saved: 5.0
moves_obtained: 4.0
hm_count: 10.0
badges: 10.0
exploration: 0.02
cut_coords: 1.0
cut_tiles: 1.0
start_menu: 0.005
pokemon_menu: 0.05
stats_menu: 0.05
bag_menu: 0.05
pokecenter: 5.0
# Really an addition to event reward
badges: 2.0
bill_saved: 2.0
start_menu: 0.0
pokemon_menu: 0.0
stats_menu: 0.0
bag_menu: 0.0
taught_cut: 10.0
explore_npcs: 0.02
explore_hidden_objs: 0.02
rocket_hideout_found: 5.0



Expand Down
15 changes: 14 additions & 1 deletion pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,12 @@ def register_hooks(self):
self.pyboy.hook_register(
None, "CheckForHiddenObject.foundMatchingObject", self.hidden_object_hook, None
)
"""
_, addr = self.pyboy.symbol_lookup("IsSpriteOrSignInFrontOfPlayer.retry")
self.pyboy.hook_register(
None, addr-1, self.sign_hook, None
)
"""
self.pyboy.hook_register(None, "HandleBlackOut", self.blackout_hook, None)
self.pyboy.hook_register(None, "SetLastBlackoutMap.done", self.blackout_update_hook, None)
# self.pyboy.hook_register(None, "UsedCut.nothingToCut", self.cut_hook, context=True)
Expand All @@ -306,6 +312,7 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] =
self.init_mem()
# We only init seen hidden objs once because they can only be found once!
self.seen_hidden_objs = {}
self.seen_signs = {}
if options.get("state", None) is not None:
self.pyboy.load_state(io.BytesIO(options["state"]))
self.reset_count += 1
Expand Down Expand Up @@ -553,7 +560,7 @@ def _get_obs(self):
# "x": np.array(player_x, dtype=np.uint8),
# "y": np.array(player_y, dtype=np.uint8),
# "map_id": np.array(map_n, dtype=np.uint8),
"badges": np.array(self.read_m("wObtainedBadges"), dtype=np.uint8),
"badges": np.array(self.read_short("wObtainedBadges").bit_count(), dtype=np.uint8),
}

def set_perfect_iv_dvs(self):
Expand Down Expand Up @@ -711,6 +718,12 @@ def cut_if_next(self):
self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_A, delay=8)
self.pyboy.tick(4 * self.action_freq, render=True)

def sign_hook(self, *args, **kwargs):
    """Record the sign/text object identified by ``hSpriteIndexOrTextID``.

    Written as an emulator hook callback (its registration in
    ``register_hooks`` is currently disabled); any positional or keyword
    arguments passed by the hook machinery are accepted and ignored.

    Reads one byte at the ``hSpriteIndexOrTextID`` symbol and one byte at
    the ``wCurMap`` symbol, then marks the ``(map_id, sign_id)`` pair as
    seen in ``self.seen_signs``.
    """
    memory = self.pyboy.memory
    sign_id = memory[self.pyboy.symbol_lookup("hSpriteIndexOrTextID")[1]]
    map_id = memory[self.pyboy.symbol_lookup("wCurMap")[1]]
    # Key by (map id, sign id). Signs go into their own dict -- the one
    # reset() initialises -- rather than seen_hidden_objs, so a sign id can
    # never overwrite (or be mistaken for) a hidden-object entry that
    # happens to share the same (map id, index) pair.
    self.seen_signs[(map_id, sign_id)] = 1

def hidden_object_hook(self, *args, **kwargs):
hidden_object_id = self.pyboy.memory[self.pyboy.symbol_lookup("wHiddenObjectIndex")[1]]
map_id = self.pyboy.memory[self.pyboy.symbol_lookup("wCurMap")[1]]
Expand Down
4 changes: 1 addition & 3 deletions pokemonred_puffer/policies/multi_convolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ def __init__(
self.register_buffer(
"unpack_shift", torch.tensor([6, 4, 2, 0], dtype=torch.uint8), persistent=False
)
self.register_buffer("binary_mask", torch.tensor([2**i for i in range(8)]))

def encode_observations(self, observations):
observations = unpack_batched_obs(observations, self.unflatten_context)
Expand Down Expand Up @@ -99,8 +98,7 @@ def encode_observations(self, observations):
.flatten()
.int(),
).reshape(restored_shape)
# > 0 doesn't risk a type conversion
badges = (observations["badges"] & self.binary_mask) > 0
badges = (torch.arange(8) + 1) <= observations["badges"]

image_observation = torch.cat((screen, visited_mask, global_map), dim=-1)
if self.channels_last:
Expand Down
48 changes: 27 additions & 21 deletions pokemonred_puffer/rewards/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
RedGymEnv,
)

import numpy as np

MUSEUM_TICKET = (0xD754, 0)


Expand Down Expand Up @@ -165,16 +163,30 @@ def get_levels_reward(self):
return 15 + (self.max_level_sum - 15) / 4


class RockTunnelReplicationEnv(BaselineRewardEnv):
class CutWithObjectRewardsEnv(BaselineRewardEnv):
def get_game_state_reward(self):
return {
"level": self.reward_config["level"] * self.get_levels_reward(),
"exploration": self.reward_config["exploration"] * sum(self.seen_coords.values()),
"taught_cut": self.reward_config["taught_cut"] * int(self.taught_cut),
"event": self.reward_config["event"] * self.update_max_event_rew(),
"seen_pokemon": self.reward_config["seen_pokemon"] * np.sum(self.seen_pokemon),
"caught_pokemon": self.reward_config["caught_pokemon"] * np.sum(self.caught_pokemon),
"moves_obtained": self.reward_config["moves_obtained"] * np.sum(self.moves_obtained),
"met_bill": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F1, 0)),
"used_cell_separator_on_bill": (
self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 3))
),
"ss_ticket": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 4)),
"met_bill_2": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 5)),
"bill_said_use_cell_separator": (
self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 6))
),
"left_bills_house_after_helping": (
self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 7))
),
"moves_obtained": self.reward_config["moves_obtained"] * sum(self.moves_obtained),
"hm_count": self.reward_config["hm_count"] * self.get_hm_count(),
"badges": self.reward_config["badges"] * self.get_badges(),
"exploration": self.reward_config["exploration"] * sum(self.seen_coords.values()),
"explore_npcs": self.reward_config["explore_npcs"] * sum(self.seen_npcs.values()),
"explore_hidden_objs": (
self.reward_config["explore_hidden_objs"] * sum(self.seen_hidden_objs.values())
),
"cut_coords": self.reward_config["cut_coords"] * sum(self.cut_coords.values()),
"cut_tiles": self.reward_config["cut_tiles"] * sum(self.cut_tiles),
"start_menu": (
Expand All @@ -187,18 +199,12 @@ def get_game_state_reward(self):
self.reward_config["stats_menu"] * self.seen_stats_menu * int(self.taught_cut)
),
"bag_menu": self.reward_config["bag_menu"] * self.seen_bag_menu * int(self.taught_cut),
# "pokecenter": self.reward_config["pokecenter"] * np.sum(self.pokecenters),
"badges": self.reward_config["badges"] * self.get_badges(),
"met_bill": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F1, 0)),
"used_cell_separator_on_bill": self.reward_config["bill_saved"]
* int(self.read_bit(0xD7F2, 3)),
"ss_ticket": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 4)),
"met_bill_2": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 5)),
"bill_said_use_cell_separator": self.reward_config["bill_saved"]
* int(self.read_bit(0xD7F2, 6)),
"left_bills_house_after_helping": self.reward_config["bill_saved"]
* int(self.read_bit(0xD7F2, 7)),
"rival3": self.reward_config["event"] * int(self.read_m(0xD665) == 4),
"taught_cut": self.reward_config["taught_cut"] * int(self.taught_cut),
"seen_pokemon": self.reward_config["seen_pokemon"] * sum(self.seen_pokemon),
"caught_pokemon": self.reward_config["caught_pokemon"] * sum(self.caught_pokemon),
"level": self.reward_config["level"] * self.get_levels_reward(),
"rocket_hideout_found": self.reward_config["rocket_hideout_found"]
* int(self.read_bit(0xD77E, 1)),
}

def get_levels_reward(self):
Expand Down

0 comments on commit 2c3e8fc

Please sign in to comment.