From 911f74a33c9d30e23d3bfdb93884c74e8d0c12b6 Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Fri, 7 Jun 2024 14:10:49 -0400 Subject: [PATCH] Full reset every 25 resets --- config.yaml | 2 +- pokemonred_puffer/wrappers/exploration.py | 18 ++++++++---------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/config.yaml b/config.yaml index 02eebff..b1af678 100644 --- a/config.yaml +++ b/config.yaml @@ -137,7 +137,7 @@ wrappers: - stream_wrapper.StreamWrapper: user: thatguy - exploration.OnResetExplorationWrapper: - full_reset_frequency: 0 + full_reset_frequency: 25 rewards: baseline.BaselineRewardEnv: diff --git a/pokemonred_puffer/wrappers/exploration.py b/pokemonred_puffer/wrappers/exploration.py index edbe6b0..e01a6b5 100644 --- a/pokemonred_puffer/wrappers/exploration.py +++ b/pokemonred_puffer/wrappers/exploration.py @@ -101,9 +101,6 @@ def __init__(self, env: RedGymEnv, reward_config: pufferlib.namespace): self.full_reset_frequency = reward_config.full_reset_frequency self.counter = 0 - def step(self, action): - pass - def reset(self, *args, **kwargs): if self.counter % self.full_reset_frequency == 0: self.counter = 0 @@ -115,6 +112,7 @@ def reset(self, *args, **kwargs): self.cut_coords.clear() self.cut_tiles.clear() self.counter += 1 + return self.env.reset(*args, **kwargs) class OnResetLowerToFixedValueWrapper(gym.Wrapper): @@ -122,24 +120,23 @@ def __init__(self, env: RedGymEnv, reward_config: pufferlib.namespace): super().__init__(env) self.fixed_value = reward_config.fixed_value - def step(self, action): - pass - def reset(self, *args, **kwargs): self.env.unwrapped.seen_coords.update( - (k, self.fixed_value["coords"]) for k, v in self.env.unwrapped.seen_coords.items() + (k, self.fixed_value["coords"]) + for k, v in self.env.unwrapped.seen_coords.items() + if v > 0 ) self.env.unwrapped.seen_map_ids[self.env.unwrapped.seen_map_ids > 0] = self.fixed_value[ "map_ids" ] self.env.unwrapped.seen_npcs.update( - (k, self.fixed_value["npc"]) for k, v in self.env.unwrapped.seen_npcs.items() + (k, self.fixed_value["npc"]) for k, v in self.env.unwrapped.seen_npcs.items() if v > 0 ) self.env.unwrapped.cut_tiles.update( - (k, self.fixed_value["cut"]) for k, v in self.env.unwrapped.seen_npcs.items() + (k, self.fixed_value["cut"]) for k, v in self.env.unwrapped.seen_npcs.items() if v > 0 ) self.env.unwrapped.cut_coords.update( - (k, self.fixed_value["cut"]) for k, v in self.env.unwrapped.seen_npcs.items() + (k, self.fixed_value["cut"]) for k, v in self.env.unwrapped.seen_npcs.items() if v > 0 ) self.env.unwrapped.explore_map[self.env.unwrapped.explore_map > 0] = self.fixed_value[ "explore" @@ -147,3 +144,4 @@ def reset(self, *args, **kwargs): self.env.unwrapped.cut_explore_map[self.env.unwrapped.cut_explore_map > 0] = ( self.fixed_value["cut"] ) + return self.env.reset(*args, **kwargs)