From 47624ab6d827dff477c5b18ed0a7fad0d1ece01b Mon Sep 17 00:00:00 2001
From: tazlin
Date: Wed, 20 Nov 2024 09:21:40 -0500
Subject: [PATCH] feat: per-epoch progress bar w/ info while training

---
 train.py | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)

diff --git a/train.py b/train.py
index a10d3e05..fb78e7df 100644
--- a/train.py
+++ b/train.py
@@ -48,6 +48,7 @@
 import optunahub
 from optuna.distributions import CategoricalDistribution, FloatDistribution, IntDistribution
 from optuna.terminator import EMMREvaluator, MedianErrorEvaluator, Terminator, TerminatorCallback
+from tqdm import tqdm
 
 from hordelib.horde import HordeLib
 
@@ -658,7 +659,10 @@ def objective(trial):
     total_loss = None
     best_epoch = best_loss = best_state_dict = None
     patience = trial.suggest_int("patience", MIN_PATIENCE, MAX_PATIENCE) if USE_PATIENCE else 0
-    for epoch in range(NUM_EPOCHS):
+    epochs_since_best = 0
+
+    pbar = tqdm(range(NUM_EPOCHS), desc="Training Progress")
+    for epoch in pbar:
         # Train the model
         model.train()
         for data, labels in train_loader:
@@ -688,12 +692,24 @@ def objective(trial):
             best_loss = total_loss
             best_epoch = epoch
             best_state_dict = model.state_dict()
+            epochs_since_best = 0
         else:
             epochs_since_best = epoch - best_epoch
 
         if USE_PATIENCE and epochs_since_best >= patience:
             # Stop early, no improvement in awhile
             break
 
+        pbar.set_description(
+            f"input_size={input_size}, layers={layers}, output_size={output_size} "
+            f"batch_size={batch}, optimizer={optimizer_name}, lr={lr}, weight_decay={weight_decay}",
+        )
+
+        pbar.set_postfix(
+            loss=total_loss,
+            best_loss=best_loss,
+            epochs_since_best=epochs_since_best,
+        )
+
     # reload the best performing model we found
     model.load_state_dict(best_state_dict)
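
Reviewer note (not part of the patch): the change pairs tqdm's set_description() with set_postfix(), so the trial's fixed hyperparameters live in the bar's prefix while the per-epoch metrics update in the suffix. Below is a minimal, self-contained sketch of that pattern; the hyperparameter values and the dummy loss curve are illustrative placeholders, not values from train.py.

from tqdm import tqdm

NUM_EPOCHS = 10  # placeholder; train.py defines its own NUM_EPOCHS
best_loss = None
epochs_since_best = 0

pbar = tqdm(range(NUM_EPOCHS), desc="Training Progress")
for epoch in pbar:
    # Stand-in for the real train/validate step inside objective().
    total_loss = 1.0 / (epoch + 1)

    if best_loss is None or total_loss < best_loss:
        best_loss = total_loss
        epochs_since_best = 0
    else:
        epochs_since_best += 1

    # Prefix: hyperparameters that are fixed for the whole trial.
    pbar.set_description("batch_size=32, lr=0.001")
    # Suffix: metrics that change every epoch.
    pbar.set_postfix(loss=total_loss, best_loss=best_loss,
                     epochs_since_best=epochs_since_best)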