Skip to content

Commit

Permalink
map id rewards
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Aug 25, 2024
1 parent 3cdfaec commit 64e98ec
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 1 deletion.
29 changes: 29 additions & 0 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ env:
exploration_inc: 1.0
exploration_max: 1.0
max_steps_scaling: 0 # 0.2 # every 10 events or items gained, multiply max_steps by 2
map_id_scalefactor: 5.0 # exploration reward multiplier for maps whose completion event/flag has not been set yet




Expand Down Expand Up @@ -295,6 +297,33 @@ rewards:
a_press: 0.0 # 0.00001
explore_warps: 0.05
use_surf: 0.05

# Reward weights read via `self.reward_config[...]` by the
# ObjectRewardRequiredEventsMapIds reward class in rewards/baseline.py.
# Each value is the multiplier applied to the matching reward component;
# 0.0 disables that component.
baseline.ObjectRewardRequiredEventsMapIds:
reward:
event: 1.0
seen_pokemon: 4.0
caught_pokemon: 4.0
moves_obtained: 4.0
hm_count: 10.0
level: 1.0
badges: 5.0
cut_coords: 0.0
cut_tiles: 0.0
start_menu: 0.0
pokemon_menu: 0.0
stats_menu: 0.0
bag_menu: 0.0
explore_hidden_objs: 0.01
explore_signs: 0.015
seen_action_bag_menu: 0.0
# required_event also weights the rival3, game_corner_rocket and
# saffron_guard components in addition to every REQUIRED_EVENTS entry.
required_event: 5.0
required_item: 5.0
useful_item: 1.0
pokecenter_heal: 0.2
# exploration is multiplied by the sum of the map-id-scaled exploration map.
exploration: 0.02
a_press: 0.0 # 0.00001
explore_warps: 0.01
use_surf: 0.5



Expand Down
33 changes: 33 additions & 0 deletions pokemonred_puffer/data/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,3 +271,36 @@ class MapIds(Enum):
0x10, # Route 10 (Rock Tunnel)
0xE9, # Silph Co 9F (Heal station)
}

# Maps a map id to the name of the flag that marks that area as "completed".
# Names starting with "EVENT_" are event flags; names starting with "HS_" are
# missable-object flags. Several maps (e.g. every floor of a multi-floor
# building) share a single completion flag, so the table is declared per flag
# and flattened into a per-map lookup. Insertion order follows the grouping.
MAP_ID_COMPLETION_EVENTS = {
    map_id: completion_flag
    for completion_flag, map_ids in (
        ("EVENT_BEAT_BROCK", (MapIds.PEWTER_GYM,)),
        ("EVENT_BEAT_MISTY", (MapIds.CERULEAN_GYM,)),
        ("EVENT_BEAT_LT_SURGE", (MapIds.VERMILION_GYM,)),
        ("EVENT_BEAT_ERIKA", (MapIds.CELADON_GYM,)),
        ("EVENT_BEAT_SABRINA", (MapIds.SAFFRON_GYM,)),
        ("EVENT_BEAT_KOGA", (MapIds.FUCHSIA_GYM,)),
        ("EVENT_BEAT_BLAINE", (MapIds.CINNABAR_GYM,)),
        ("EVENT_BEAT_VIRIDIAN_GYM_GIOVANNI", (MapIds.VIRIDIAN_GYM,)),
        (
            "EVENT_BEAT_ROCKET_HIDEOUT_GIOVANNI",
            (
                MapIds.GAME_CORNER,
                MapIds.ROCKET_HIDEOUT_B1F,
                MapIds.ROCKET_HIDEOUT_B2F,
                MapIds.ROCKET_HIDEOUT_B3F,
                MapIds.ROCKET_HIDEOUT_B4F,
                MapIds.ROCKET_HIDEOUT_ELEVATOR,
            ),
        ),
        (
            "EVENT_BEAT_SILPH_CO_GIOVANNI",
            (
                MapIds.SILPH_CO_1F,
                MapIds.SILPH_CO_2F,
                MapIds.SILPH_CO_3F,
                MapIds.SILPH_CO_4F,
                MapIds.SILPH_CO_5F,
                MapIds.SILPH_CO_6F,
                MapIds.SILPH_CO_7F,
                MapIds.SILPH_CO_8F,
                MapIds.SILPH_CO_9F,
                MapIds.SILPH_CO_10F,
                MapIds.SILPH_CO_11F,
                MapIds.SILPH_CO_ELEVATOR,
            ),
        ),
        (
            "HS_POKEMON_MANSION_B1F_ITEM_5",
            (
                MapIds.POKEMON_MANSION_1F,
                MapIds.POKEMON_MANSION_2F,
                MapIds.POKEMON_MANSION_3F,
                MapIds.POKEMON_MANSION_B1F,
            ),
        ),
    )
    for map_id in map_ids
}
26 changes: 25 additions & 1 deletion pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
USEFUL_ITEMS,
Items,
)
from pokemonred_puffer.data.map import MapIds
from pokemonred_puffer.data.map import MAP_ID_COMPLETION_EVENTS, MapIds
from pokemonred_puffer.data.missable_objects import MissableFlags
from pokemonred_puffer.data.party import PartyMons
from pokemonred_puffer.data.strength_puzzles import STRENGTH_SOLUTIONS
Expand Down Expand Up @@ -130,6 +130,7 @@ def __init__(self, env_config: pufferlib.namespace):
self.exploration_inc = env_config.exploration_inc
self.exploration_max = env_config.exploration_max
self.max_steps_scaling = env_config.max_steps_scaling
self.map_id_scalefactor = env_config.map_id_scalefactor
self.action_space = ACTION_SPACE

# Obs space-related. TODO: avoid hardcoding?
Expand Down Expand Up @@ -343,6 +344,7 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] =
self.caught_pokemon.fill(0)
self.moves_obtained.fill(0)
self.explore_map *= 0
self.reward_explore_map *= 0
self.cut_explore_map *= 0
self.reset_mem()

Expand Down Expand Up @@ -388,6 +390,7 @@ def init_mem(self):
# All map ids have the same size, right?
self.seen_coords: dict[int, dict[tuple[int, int, int], int]] = {}
self.explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32)
self.reward_explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32)
self.cut_explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32)
self.seen_map_ids = np.zeros(256)
self.seen_npcs = {}
Expand Down Expand Up @@ -1466,6 +1469,10 @@ def update_seen_coords(self):
self.explore_map[local_to_global(y_pos, x_pos, map_n)] + inc,
self.exploration_max,
)
self.reward_explore_map[local_to_global(y_pos, x_pos, map_n)] = min(
self.explore_map[local_to_global(y_pos, x_pos, map_n)] + inc,
self.exploration_max,
) * self.map_id_scaling(map_n)
# self.seen_global_coords[local_to_global(y_pos, x_pos, map_n)] = 1
self.seen_map_ids[map_n] = 1

Expand Down Expand Up @@ -1687,3 +1694,20 @@ def get_events_sum(self):
- int(self.read_bit(*MUSEUM_TICKET)),
0,
)

def map_id_scaling(self, map_n: int) -> float:
    """Return the exploration multiplier for the map with id ``map_n``.

    Maps without an entry in ``MAP_ID_COMPLETION_EVENTS`` always yield 1.0.
    For a listed map, ``self.map_id_scalefactor`` is returned while its
    completion flag is still falsy, and 1.0 once the flag is set.

    Flags named ``EVENT_*`` are read through ``self.events.get_event`` and
    flags named ``HS_*`` through ``self.missables.get_missable``; a flag with
    any other prefix is treated as already completed.
    """
    completion_flag = MAP_ID_COMPLETION_EVENTS.get(MapIds(map_n))
    if completion_flag is None:
        return 1.0

    # Pick the flag store from the name prefix; the prefixes are mutually
    # exclusive, so at most one lookup is performed.
    if completion_flag.startswith("EVENT_"):
        still_pending = not self.events.get_event(completion_flag)
    elif completion_flag.startswith("HS_"):
        still_pending = not self.missables.get_missable(completion_flag)
    else:
        still_pending = False

    return self.map_id_scalefactor if still_pending else 1.0
62 changes: 62 additions & 0 deletions pokemonred_puffer/rewards/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,3 +361,65 @@ def get_levels_reward(self):
return self.max_level_sum
else:
return 15 + (self.max_level_sum - 15) / 4


class ObjectRewardRequiredEventsMapIds(BaselineRewardEnv):
    """Reward environment whose exploration term uses ``reward_explore_map``.

    Every reward component is the product of a weight taken from
    ``self.reward_config`` and a quantity tracked on the environment. The
    ``"exploration"`` component sums ``self.reward_explore_map`` rather than
    the plain exploration map.
    """

    def get_game_state_reward(self):
        """Return a dict mapping each reward component name to its weighted value.

        The result merges four groups: the fixed named components, one entry
        per name in ``REQUIRED_EVENTS``, one entry per item in
        ``REQUIRED_ITEMS`` and one entry per item in ``USEFUL_ITEMS``. Later
        groups overwrite earlier ones if a key repeats.
        """
        _, wBagItems = self.pyboy.symbol_lookup("wBagItems")
        numBagItems = self.read_m("wNumBagItems")
        # Take every second byte starting at wBagItems, one per bag slot.
        # NOTE(review): presumably the item-id byte of each (id, quantity)
        # pair -- confirm against the bag memory layout.
        bag_item_ids = set(self.pyboy.memory[wBagItems : wBagItems + 2 * numBagItems : 2])

        return (
            {
                "event": self.reward_config["event"] * self.update_max_event_rew(),
                "seen_pokemon": self.reward_config["seen_pokemon"] * sum(self.seen_pokemon),
                "caught_pokemon": self.reward_config["caught_pokemon"] * sum(self.caught_pokemon),
                "moves_obtained": self.reward_config["moves_obtained"] * sum(self.moves_obtained),
                "hm_count": self.reward_config["hm_count"] * self.get_hm_count(),
                "level": self.reward_config["level"] * self.get_levels_reward(),
                "badges": self.reward_config["badges"] * self.get_badges(),
                "cut_coords": self.reward_config["cut_coords"] * sum(self.cut_coords.values()),
                "cut_tiles": self.reward_config["cut_tiles"] * sum(self.cut_tiles.values()),
                "start_menu": self.reward_config["start_menu"] * self.seen_start_menu,
                "pokemon_menu": self.reward_config["pokemon_menu"] * self.seen_pokemon_menu,
                "stats_menu": self.reward_config["stats_menu"] * self.seen_stats_menu,
                "bag_menu": self.reward_config["bag_menu"] * self.seen_bag_menu,
                # Apply the configured weight like every other component;
                # previously the raw sum was returned and the weight ignored.
                "explore_hidden_objs": sum(self.seen_hidden_objs.values())
                * self.reward_config["explore_hidden_objs"],
                "explore_signs": sum(self.seen_signs.values())
                * self.reward_config["explore_signs"],
                "seen_action_bag_menu": self.seen_action_bag_menu
                * self.reward_config["seen_action_bag_menu"],
                "pokecenter_heal": self.pokecenter_heal * self.reward_config["pokecenter_heal"],
                # The next three components reuse the "required_event" weight.
                "rival3": self.reward_config["required_event"]
                * int(self.read_m("wSSAnne2FCurScript") == 4),
                "game_corner_rocket": self.reward_config["required_event"]
                * float(self.missables.get_missable("HS_GAME_CORNER_ROCKET")),
                "saffron_guard": self.reward_config["required_event"]
                * float(self.wd728.get_bit("GAVE_SAFFRON_GUARD_DRINK")),
                "a_press": len(self.a_press) * self.reward_config["a_press"],
                "warps": len(self.seen_warps) * self.reward_config["explore_warps"],
                "use_surf": self.reward_config["use_surf"] * self.use_surf,
                "exploration": self.reward_config["exploration"] * np.sum(self.reward_explore_map),
            }
            | {
                event: self.reward_config["required_event"] * float(self.events.get_event(event))
                for event in REQUIRED_EVENTS
            }
            | {
                item.name: self.reward_config["required_item"] * float(item.value in bag_item_ids)
                for item in REQUIRED_ITEMS
            }
            | {
                item.name: self.reward_config["useful_item"] * float(item.value in bag_item_ids)
                for item in USEFUL_ITEMS
            }
        )

    def get_levels_reward(self):
        """Return the level reward from the highest party level sum seen so far.

        Updates ``self.max_level_sum`` with the current sum of party levels if
        it is larger. Below 15 the sum is returned unchanged; from 15 upwards
        each additional level counts for a quarter.
        """
        party_size = self.read_m("wPartyCount")
        party_levels = [self.read_m(f"wPartyMon{i+1}Level") for i in range(party_size)]
        self.max_level_sum = max(self.max_level_sum, sum(party_levels))
        if self.max_level_sum < 15:
            return self.max_level_sum
        else:
            return 15 + (self.max_level_sum - 15) / 4

0 comments on commit 64e98ec

Please sign in to comment.