# Reconstructed from a unified diff that had been collapsed onto one line.
# This is the post-change `done_training` method from
# pokemonred_puffer/cleanrl_puffer.py. The same diff adds `one_epoch: True`
# to config.yaml (directly after `EVENT_GOT_HM01: 180`); that is the flag
# read below.
def done_training(self):
    """Return True once the training loop should stop.

    Training is considered finished when any of the following holds:

    * ``self.early_stop`` is truthy;
    * ``self.global_step`` has reached ``self.config.total_timesteps``;
    * ``self.config.one_epoch`` is truthy and at least one key of
      ``self.states`` contains ``"EVENT_BEAT_CHAMPION_RIVAL"``.

    Returns:
        bool: always a real ``bool``. The previous chained ``or``/``and``
        expression could hand back one of its operands instead (for
        example an empty ``self.states`` container).
    """
    if self.early_stop or self.global_step >= self.config.total_timesteps:
        return True

    # Configs written before the flag existed have no `one_epoch` entry;
    # treat a missing flag as "disabled" instead of raising AttributeError.
    if not getattr(self.config, "one_epoch", False):
        return False

    states = self.states
    if not states:
        # Nothing recorded yet (empty or None): no event can have fired.
        return False

    # Substring match on the keys, as in the original check; iterating a
    # mapping yields its keys, so the explicit `.keys()` call is not needed.
    return any("EVENT_BEAT_CHAMPION_RIVAL" in key for key in states)