Skip to content

Commit

Permalink
use_patience
Browse files Browse the repository at this point in the history
  • Loading branch information
db0 committed Nov 21, 2024
1 parent ed3a0d4 commit 3a6aea4
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,14 @@

# Number of trials to run.
# Each trial generates a new neural network topology with new hyper parameters and trains it.
NUMBER_OF_STUDY_TRIALS = 150
NUMBER_OF_STUDY_TRIALS = 100

# Hyper parameter search bounds
NUM_EPOCHS = 2000
# Patience is an custom terminator that stops a training if no improvement has happened in this many epochs
USE_PATIENCE = False
MIN_PATIENCE = 25
MAX_PATIENCE = 300 # if no improvement in this many epochs, stop early
MAX_PATIENCE = 300
MIN_NUMBER_OF_EPOCHS = 50
MIN_HIDDEN_LAYERS = 3
MAX_HIDDEN_LAYERS = 9
Expand Down Expand Up @@ -618,7 +620,7 @@ def objective(trial):

total_loss = None
best_epoch = best_loss = best_state_dict = None
patience = trial.suggest_int("patience", MIN_PATIENCE, MAX_PATIENCE)
patience = trial.suggest_int("patience", MIN_PATIENCE, MAX_PATIENCE) if USE_PATIENCE else 0
for epoch in range(NUM_EPOCHS):
# Train the model
model.train()
Expand Down Expand Up @@ -651,7 +653,7 @@ def objective(trial):
best_state_dict = model.state_dict()
else:
epochs_since_best = epoch - best_epoch
if epochs_since_best >= patience:
if USE_PATIENCE and epochs_since_best >= patience:
# Stop early, no improvement in awhile
break

Expand Down

0 comments on commit 3a6aea4

Please sign in to comment.