Skip to content

Commit

Permalink
Wrappers for reset
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Jun 7, 2024
1 parent 8898c63 commit ff94a87
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 10 deletions.
20 changes: 20 additions & 0 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -109,16 +109,35 @@ wrappers:
bag_menu: 0.998
action_bag_menu: 0.998
forgetting_frequency: 10
- exploration.OnResetExplorationWrapper:
full_reset_frequency: 0

finite_coords:
- stream_wrapper.StreamWrapper:
user: thatguy
- exploration.MaxLengthWrapper:
capacity: 1750
- exploration.OnResetExplorationWrapper:
full_reset_frequency: 0

stream_only:
- stream_wrapper.StreamWrapper:
user: thatguy
- exploration.OnresetLowerToFixedValueWrapper:
fixed_value:
coords: 0.33
map_ids: 0.33
npc: 0.33
cut: 0.33
explore: 0.33
- exploration.OnResetExplorationWrapper:
full_reset_frequency: 0

fixed_reset_value:
- stream_wrapper.StreamWrapper:
user: thatguy
- exploration.OnResetExplorationWrapper:
full_reset_frequency: 0

rewards:
baseline.BaselineRewardEnv:
Expand Down Expand Up @@ -177,6 +196,7 @@ rewards:
stats_menu: 0.0
bag_menu: 0.0
rocket_hideout_found: 5.0
explore_hidden_objs: 0.02



Expand Down
9 changes: 0 additions & 9 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,6 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] =
with open(self.init_state_path, "rb") as f:
self.pyboy.load_state(f)
self.reset_count = 0
self.explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32)
self.cut_explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32)
self.base_event_flags = sum(
self.read_m(i).bit_count()
for i in range(EVENT_FLAGS_START, EVENT_FLAGS_START + EVENTS_FLAGS_LENGTH)
Expand Down Expand Up @@ -396,13 +394,6 @@ def init_mem(self):
self.seen_action_bag_menu = 0

def reset_mem(self):
self.seen_coords.update((k, 0) for k, _ in self.seen_coords.items())
self.seen_map_ids *= 0
self.seen_npcs.update((k, 0) for k, _ in self.seen_npcs.items())

self.cut_coords.update((k, 0) for k, _ in self.cut_coords.items())
self.cut_tiles.update((k, 0) for k, _ in self.cut_tiles.items())

self.seen_start_menu = 0
self.seen_pokemon_menu = 0
self.seen_stats_menu = 0
Expand Down
2 changes: 2 additions & 0 deletions pokemonred_puffer/rewards/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ def get_game_state_reward(self):
"rival3": self.reward_config["event"] * int(self.read_m(0xD665) == 4),
"rocket_hideout_found": self.reward_config["rocket_hideout_found"]
* int(self.read_bit(0xD77E, 1)),
"explore_hidden_objs": sum(self.seen_hidden_objs.values())
* self.reward_config["explore_hidden_objs"],
}

def get_levels_reward(self):
Expand Down
51 changes: 50 additions & 1 deletion pokemonred_puffer/wrappers/exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pufferlib
from pokemonred_puffer.environment import RedGymEnv
from pokemonred_puffer.global_map import local_to_global
from pokemonred_puffer.global_map import GLOBAL_MAP_SHAPE, local_to_global


class LRUCache:
Expand Down Expand Up @@ -93,3 +93,52 @@ def step(self, action):
def reset(self, *args, **kwargs):
self.cache.clear()
return self.env.reset(*args, **kwargs)


class OnResetExplorationWrapper(gym.Wrapper):
def __init__(self, env: RedGymEnv, reward_config: pufferlib.namespace):
super().__init__(env)
self.full_reset_frequency = reward_config.full_reset_frequency
self.counter = 0

def step(self, action):
pass

def reset(self, *args, **kwargs):
if self.counter % self.full_reset_frequency == 0:
self.counter = 0
self.explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32)
self.cut_explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32)
self.seen_coords.update((k, 0) for k, _ in self.seen_coords.items())
self.seen_map_ids *= 0
self.seen_npcs.update((k, 0) for k, _ in self.seen_npcs.items())

self.cut_coords.update((k, 0) for k, _ in self.cut_coords.items())
self.cut_tiles.update((k, 0) for k, _ in self.cut_tiles.items())
self.counter += 1


class OnResetLowerToFixedValueWrapper(gym.Wrapper):
def __init__(self, env: RedGymEnv, reward_config: pufferlib.namespace):
super().__init__(env)
self.fixed_value = reward_config.fixed_value

def step(self, action):
pass

def reset(self, *args, **kwargs):
self.env.unwrapped.seen_coords.update(
(k, self.fixed_value["coords"]) for k, v in self.env.unwrapped.seen_coords.items()
)
self.env.unwrapped.seen_map_ids *= self.fixed_value["map_ids"]
self.env.unwrapped.seen_npcs.update(
(k, self.fixed_value["npc"]) for k, v in self.env.unwrapped.seen_npcs.items()
)
self.env.unwrapped.cut_tiles.update(
(k, self.fixed_value["cut"]) for k, v in self.env.unwrapped.seen_npcs.items()
)
self.env.unwrapped.cut_coords.update(
(k, self.fixed_value["cut"]) for k, v in self.env.unwrapped.seen_npcs.items()
)
self.env.unwrapped.explore_map = self.fixed_value["explore"]
self.cut_explore_map = self.fixed_value["cut"]

0 comments on commit ff94a87

Please sign in to comment.