diff --git a/config.yaml b/config.yaml index 601bbda..88e3fd4 100644 --- a/config.yaml +++ b/config.yaml @@ -133,12 +133,14 @@ wrappers: capacity: 1750 - exploration.OnResetExplorationWrapper: full_reset_frequency: 1 + jitter: 0 stream_only: - stream_wrapper.StreamWrapper: user: thatguy - exploration.OnResetExplorationWrapper: full_reset_frequency: 1 + jitter: 2 fixed_reset_value: - stream_wrapper.StreamWrapper: @@ -152,6 +154,7 @@ wrappers: explore: 0.33 - exploration.OnResetExplorationWrapper: full_reset_frequency: 25 + jitter: 0 rewards: baseline.BaselineRewardEnv: diff --git a/pokemonred_puffer/wrappers/exploration.py b/pokemonred_puffer/wrappers/exploration.py index cd76183..747a8ad 100644 --- a/pokemonred_puffer/wrappers/exploration.py +++ b/pokemonred_puffer/wrappers/exploration.py @@ -1,4 +1,5 @@ from collections import OrderedDict +import random import gymnasium as gym import numpy as np @@ -99,10 +100,11 @@ class OnResetExplorationWrapper(gym.Wrapper): def __init__(self, env: RedGymEnv, reward_config: pufferlib.namespace): super().__init__(env) self.full_reset_frequency = reward_config.full_reset_frequency + self.jitter = reward_config.jitter self.counter = 0 def reset(self, *args, **kwargs): - if self.counter % self.full_reset_frequency == 0: + if (self.counter + random.randint(0, self.jitter)) >= self.full_reset_frequency: self.counter = 0 self.env.unwrapped.explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32) self.env.unwrapped.cut_explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32)