Skip to content

Commit

Permalink
Add jitter to reset exploration wrapper
Browse files (browse the repository at this point in the history)
  • Loading branch information
thatguy11325 committed Jun 15, 2024
1 parent 79b49e4 commit b4a9979
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
3 changes: 3 additions & 0 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,14 @@ wrappers:
capacity: 1750
- exploration.OnResetExplorationWrapper:
full_reset_frequency: 1
jitter: 0

stream_only:
- stream_wrapper.StreamWrapper:
user: thatguy
- exploration.OnResetExplorationWrapper:
full_reset_frequency: 1
jitter: 2

fixed_reset_value:
- stream_wrapper.StreamWrapper:
Expand All @@ -152,6 +154,7 @@ wrappers:
explore: 0.33
- exploration.OnResetExplorationWrapper:
full_reset_frequency: 25
jitter: 0

rewards:
baseline.BaselineRewardEnv:
Expand Down
4 changes: 3 additions & 1 deletion pokemonred_puffer/wrappers/exploration.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections import OrderedDict
import random
import gymnasium as gym
import numpy as np

Expand Down Expand Up @@ -99,10 +100,11 @@ class OnResetExplorationWrapper(gym.Wrapper):
def __init__(self, env: RedGymEnv, reward_config: pufferlib.namespace):
super().__init__(env)
self.full_reset_frequency = reward_config.full_reset_frequency
self.jitter = reward_config.jitter
self.counter = 0

def reset(self, *args, **kwargs):
if self.counter % self.full_reset_frequency == 0:
if (self.counter + random.randint(0, self.jitter)) >= self.full_reset_frequency:
self.counter = 0
self.env.unwrapped.explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32)
self.env.unwrapped.cut_explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32)
Expand Down

0 comments on commit b4a9979

Please sign in to comment.