Skip to content

Commit

Permalink
feat: per-epoch progress bar w/ info while training
Browse files Browse the repository at this point in the history
  • Loading branch information
tazlin committed Nov 22, 2024
1 parent aea1d9a commit 47624ab
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import optunahub
from optuna.distributions import CategoricalDistribution, FloatDistribution, IntDistribution
from optuna.terminator import EMMREvaluator, MedianErrorEvaluator, Terminator, TerminatorCallback
from tqdm import tqdm

from hordelib.horde import HordeLib

Expand Down Expand Up @@ -658,7 +659,10 @@ def objective(trial):
total_loss = None
best_epoch = best_loss = best_state_dict = None
patience = trial.suggest_int("patience", MIN_PATIENCE, MAX_PATIENCE) if USE_PATIENCE else 0
for epoch in range(NUM_EPOCHS):
epochs_since_best = 0

pbar = tqdm(range(NUM_EPOCHS), desc="Training Progress")
for epoch in pbar:
# Train the model
model.train()
for data, labels in train_loader:
Expand Down Expand Up @@ -688,12 +692,24 @@ def objective(trial):
best_loss = total_loss
best_epoch = epoch
best_state_dict = model.state_dict()
epochs_since_best = 0
else:
epochs_since_best = epoch - best_epoch
if USE_PATIENCE and epochs_since_best >= patience:
# Stop early, no improvement in awhile
break

pbar.set_description(
f"input_size={input_size}, layers={layers}, output_size={output_size} "
f"batch_size={batch}, optimizer={optimizer_name}, lr={lr}, weight_decay={weight_decay}",
)

pbar.set_postfix(
loss=total_loss,
best_loss=best_loss,
epochs_since_best=epochs_since_best,
)

# reload the best performing model we found
model.load_state_dict(best_state_dict)

Expand Down

0 comments on commit 47624ab

Please sign in to comment.