Add a time based early stop
thatguy11325 committed Oct 26, 2024
1 parent 02dbe6c commit baffed8
Showing 2 changed files with 18 additions and 1 deletion.
7 changes: 7 additions & 0 deletions config.yaml
@@ -134,6 +134,13 @@ train:
   sqlite_wrapper: True
   archive_states: True
   swarm: True
+  early_stop:
+    # event name: minutes. If we don't satisfy each condition
+    # within its time limit, we stop early.
+    # The defaults have a margin of error.
+    EVENT_BEAT_BROCK: 45
+    EVENT_BEAT_MISTY: 90
+    EVENT_GOT_HM01: 180
 
 wrappers:
   empty:
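For reference, each entry maps an event name to a deadline in minutes of wall-clock training time; a milestone that has not been reached by its deadline triggers an early stop. A minimal sketch of that interpretation, in standalone Python with a made-up helper name (not the repo's code):

# Minimal sketch, assuming `reached` is the set of event names seen so far
# and `uptime_s` is elapsed training time in seconds.
def should_early_stop(early_stop: dict, reached: set, uptime_s: float) -> bool:
    for event, minutes in early_stop.items():
        if event not in reached and uptime_s > minutes * 60:
            return True
    return False

thresholds = {"EVENT_BEAT_BROCK": 45, "EVENT_BEAT_MISTY": 90, "EVENT_GOT_HM01": 180}
print(should_early_stop(thresholds, {"EVENT_BEAT_BROCK"}, 50 * 60))  # False: Brock beaten, later deadlines not yet due
print(should_early_stop(thresholds, set(), 50 * 60))                 # True: the 45-minute Brock deadline was missed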
12 changes: 11 additions & 1 deletion pokemonred_puffer/cleanrl_puffer.py
@@ -142,6 +142,7 @@ class CleanPuffeRL:
     states: dict = field(default_factory=lambda: defaultdict(partial(deque, maxlen=1)))
     event_tracker: dict = field(default_factory=lambda: {})
     max_event_count: int = 0
+    early_stop: bool = False
 
     def __post_init__(self):
         seed_everything(self.config.seed, self.config.torch_deterministic)
@@ -318,6 +319,14 @@ def evaluate(self):
             # pull a state within that list
             new_state = random.choice(self.states[new_state_key])
         """
+        if self.config.train.early_stop:
+            for event, minutes in list(self.config.train.early_stop.items()):
+                if any(event in key for key in self.states.keys()):
+                    del self.config.train.early_stop[event]
+                elif self.profile.uptime > minutes * 60 and all(
+                    event not in key for key in self.states.keys()
+                ):
+                    self.early_stop = True
 
         # V2 implementation
         # check if we have a new highest required_count with N save states available
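Note that the check above matches event names as substrings of the archived state keys. A tiny illustration with hypothetical key names (the real key format comes from the repo's state archiving and may differ):

# Hypothetical state keys for illustration only; real keys come from the
# state-archiving wrapper and may be formatted differently.
states = {
    "3_EVENT_BEAT_BROCK": ["state_a"],
    "7_EVENT_BEAT_MISTY": ["state_b"],
}
print(any("EVENT_BEAT_BROCK" in key for key in states.keys()))  # True  -> deadline satisfied, entry removed
print(any("EVENT_GOT_HM01" in key for key in states.keys()))    # False -> can trigger the early stop once overdue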
@@ -366,6 +375,7 @@ def evaluate(self):
                 ),
             )
         self.vecenv.async_reset()
+        # drain any INFO
         key_set = self.event_tracker.keys()
         while True:
             # We connect each time just in case we block the wrappers
@@ -649,7 +659,7 @@ def calculate_loss(self, pg_loss, entropy_loss, v_loss):
         self.optimizer.step()
 
     def done_training(self):
-        return self.global_step >= self.config.total_timesteps
+        return self.early_stop or self.global_step >= self.config.total_timesteps
 
     def __enter__(self):
         return self
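For context, done_training() is what the outer training loop polls, so setting early_stop ends the run at the next check. A self-contained sketch of that loop shape; the driver and class below are stand-ins, not the repo's actual code:

# Rough sketch of the loop that done_training() gates; everything here is a
# placeholder except the stopping rule, which mirrors the committed change.
class _TrainerSketch:
    def __init__(self, total_timesteps: int):
        self.global_step = 0
        self.early_stop = False
        self.total_timesteps = total_timesteps

    def done_training(self) -> bool:
        # Stop on the early-stop flag or on the step budget, whichever comes first.
        return self.early_stop or self.global_step >= self.total_timesteps

trainer = _TrainerSketch(total_timesteps=1_000)
while not trainer.done_training():
    trainer.global_step += 100
    if trainer.global_step >= 300:  # pretend a deadline was missed during evaluation
        trainer.early_stop = True
print(trainer.global_step)  # 300 -- the loop exits early, well before total_timesteps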
