diff --git a/config.yaml b/config.yaml index d1bf261..ec12076 100644 --- a/config.yaml +++ b/config.yaml @@ -134,6 +134,13 @@ train: sqlite_wrapper: True archive_states: True swarm: True + early_stop: + # event name: minutes. If we dont satisfy each condition + # we early stop + # The defaults have a margin of error + EVENT_BEAT_BROCK: 45 + EVENT_BEAT_MISTY: 90 + EVENT_GOT_HM01: 180 wrappers: empty: diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py index 864e98e..086cca9 100644 --- a/pokemonred_puffer/cleanrl_puffer.py +++ b/pokemonred_puffer/cleanrl_puffer.py @@ -142,6 +142,7 @@ class CleanPuffeRL: states: dict = field(default_factory=lambda: defaultdict(partial(deque, maxlen=1))) event_tracker: dict = field(default_factory=lambda: {}) max_event_count: int = 0 + early_stop: bool = False def __post_init__(self): seed_everything(self.config.seed, self.config.torch_deterministic) @@ -318,6 +319,14 @@ def evaluate(self): # pull a state within that list new_state = random.choice(self.states[new_state_key]) """ + if self.config.train.early_stop: + for event, minutes in self.config.train.early_stop.values(): + if any(event in key for key in self.state.keys()): + del self.config.train.early_stop[event] + elif self.profile.uptime > minutes * 60 and all( + event not in key for key in self.states.keys() + ): + self.early_stop = True # V2 implementation # check if we have a new highest required_count with N save states available @@ -366,6 +375,7 @@ def evaluate(self): ), ) self.vecenv.async_reset() + # drain any INFO key_set = self.event_tracker.keys() while True: # We connect each time just in case we block the wrappers @@ -649,7 +659,7 @@ def calculate_loss(self, pg_loss, entropy_loss, v_loss): self.optimizer.step() def done_training(self): - return self.global_step >= self.config.total_timesteps + return self.early_stop or self.global_step >= self.config.total_timesteps def __enter__(self): return self