Skip to content

Commit

Permalink
Full reset every 25 resets
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Jun 7, 2024
1 parent d70da98 commit 911f74a
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 11 deletions.
2 changes: 1 addition & 1 deletion config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ wrappers:
- stream_wrapper.StreamWrapper:
user: thatguy
- exploration.OnResetExplorationWrapper:
full_reset_frequency: 0
full_reset_frequency: 25

rewards:
baseline.BaselineRewardEnv:
Expand Down
18 changes: 8 additions & 10 deletions pokemonred_puffer/wrappers/exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,6 @@ def __init__(self, env: RedGymEnv, reward_config: pufferlib.namespace):
self.full_reset_frequency = reward_config.full_reset_frequency
self.counter = 0

def step(self, action):
pass

def reset(self, *args, **kwargs):
if self.counter % self.full_reset_frequency == 0:
self.counter = 0
Expand All @@ -115,35 +112,36 @@ def reset(self, *args, **kwargs):
self.cut_coords.clear()
self.cut_tiles.clear()
self.counter += 1
return self.env.reset(*args, **kwargs)


class OnResetLowerToFixedValueWrapper(gym.Wrapper):
def __init__(self, env: RedGymEnv, reward_config: pufferlib.namespace):
super().__init__(env)
self.fixed_value = reward_config.fixed_value

def step(self, action):
pass

def reset(self, *args, **kwargs):
self.env.unwrapped.seen_coords.update(
(k, self.fixed_value["coords"]) for k, v in self.env.unwrapped.seen_coords.items()
(k, self.fixed_value["coords"])
for k, v in self.env.unwrapped.seen_coords.items()
if v > 0
)
self.env.unwrapped.seen_map_ids[self.env.unwrapped.seen_map_ids > 0] = self.fixed_value[
"map_ids"
]
self.env.unwrapped.seen_npcs.update(
(k, self.fixed_value["npc"]) for k, v in self.env.unwrapped.seen_npcs.items()
(k, self.fixed_value["npc"]) for k, v in self.env.unwrapped.seen_npcs.items() if v > 0
)
self.env.unwrapped.cut_tiles.update(
(k, self.fixed_value["cut"]) for k, v in self.env.unwrapped.seen_npcs.items()
(k, self.fixed_value["cut"]) for k, v in self.env.unwrapped.seen_npcs.items() if v > 0
)
self.env.unwrapped.cut_coords.update(
(k, self.fixed_value["cut"]) for k, v in self.env.unwrapped.seen_npcs.items()
(k, self.fixed_value["cut"]) for k, v in self.env.unwrapped.seen_npcs.items() if v > 0
)
self.env.unwrapped.explore_map[self.env.unwrapped.explore_map > 0] = self.fixed_value[
"explore"
]
self.env.unwrapped.cut_explore_map[self.env.unwrapped.cut_explore_map > 0] = (
self.fixed_value["cut"]
)
return self.env.reset(*args, **kwargs)

0 comments on commit 911f74a

Please sign in to comment.