From e651fe04b0dd5404d0297b2245920a979870e045 Mon Sep 17 00:00:00 2001
From: tazlin
Date: Fri, 22 Nov 2024 14:45:24 -0500
Subject: [PATCH] fix: extra logging and optional progress bars to train.py

---
 train.py | 66 +++++++++++++++++++++++++++++++++++++++++++++++++++---------------
 1 file changed, 51 insertions(+), 15 deletions(-)

diff --git a/train.py b/train.py
index e36bd8c8..cf4230f9 100644
--- a/train.py
+++ b/train.py
@@ -39,6 +39,7 @@
 import optunahub
 import torch
 import torch.nn as nn
+from loguru import logger
 from optuna.distributions import CategoricalDistribution, FloatDistribution, IntDistribution
 from optuna.terminator import EMMREvaluator, MedianErrorEvaluator, Terminator, TerminatorCallback
 from torch import optim
@@ -224,6 +225,14 @@ def parse_args():
         help="Path to validation data file",
     )
 
+    parser.add_argument(
+        "-p",
+        "--progress-bars",
+        action="store_true",
+        default=False,
+        help="Enable progress bars for epoch and trial progress",
+    )
+
     # Study parameters
     parser.add_argument("--study-trials", type=int, default=2000, help="Number of trials to run")
 
@@ -460,6 +469,7 @@ def test_one_by_one(model_filename):
 
 class KudosDataset(Dataset):
     def __init__(self, filename):
+        logger.debug(f"Loading dataset from {filename}")
         self.data = []
         self.labels = []
 
@@ -487,6 +497,7 @@ def __init__(self, filename):
             self.labels.append(payload["time_to_generate"])
         self.labels = torch.tensor(self.labels).float()
         self.mixed_data = torch.stack(self.data)
+        logger.debug(f"Loaded {len(self.data)} samples")
 
     @classmethod
     def payload_to_tensor(cls, payload):
@@ -580,6 +591,9 @@ def __getitem__(self, idx):
 
 
 def create_sequential_model(trial, layer_sizes, input_size, output_size=1):
+    logger.debug(
+        f"Creating model with input size {input_size}, output size {output_size}, and layer sizes {layer_sizes}",
+    )
     # Define the layer sizes
     layer_sizes = [input_size] + layer_sizes + [output_size]
 
@@ -601,7 +615,8 @@ def create_sequential_model(trial, layer_sizes, input_size, output_size=1):
     return nn.Sequential(*layers)
 
 
-def objective(trial):
+def objective(trial: optuna.Trial) -> float:
+    logger.debug(f"Starting trial {trial.number}")
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     trial.set_user_attr("name", "predict_kudos")
@@ -656,7 +671,13 @@ def objective(trial):
     patience = trial.suggest_int("patience", MIN_PATIENCE, MAX_PATIENCE) if USE_PATIENCE else 0
     epochs_since_best = 0
 
-    pbar = tqdm(range(NUM_EPOCHS), desc="Training Progress")
+    if ENABLE_PROGRESS_BARS:
+        pbar = tqdm(range(NUM_EPOCHS), desc="Training Progress")
+    else:
+        pbar = range(NUM_EPOCHS)
+
+    logger.debug(f"Starting training for {NUM_EPOCHS} epochs")
+
     for epoch in pbar:
         # Train the model
         model.train()
@@ -694,15 +715,22 @@ def objective(trial):
             # Stop early, no improvement in awhile
             break
 
-        pbar.set_description(
+        info_str = (
             f"input_size={input_size}, layers={layers}, output_size={output_size} "
-            f"batch_size={batch}, optimizer={optimizer_name}, lr={lr}, weight_decay={weight_decay}",
+            f"batch_size={batch}, optimizer={optimizer_name}, lr={lr}, weight_decay={weight_decay}"
         )
-
-        pbar.set_postfix(
-            loss=total_loss,
-            best_loss=best_loss,
-            epochs_since_best=epochs_since_best,
+        if ENABLE_PROGRESS_BARS:
+            pbar.set_description(info_str)
+        logger.debug(info_str)
+
+        if ENABLE_PROGRESS_BARS:
+            pbar.set_postfix(
+                loss=total_loss,
+                best_loss=best_loss,
+                epochs_since_best=epochs_since_best,
+            )
+        logger.debug(
+            f"Epoch: {epoch}, Loss: {total_loss}, Best Loss: {best_loss}, Epochs since best: {epochs_since_best}",
        )
 
     # reload the best performing model we found
@@ -712,7 +740,9 @@ def objective(trial):
     filename = f"kudos_models/kudos-{STUDY_VERSION}-{trial.number}.ckpt"
     with open(filename, "wb") as outfile:
         pickle.dump(model.to("cpu"), outfile)
+    logger.debug(f"Saved model to {filename}")
 
+    logger.debug(f"Finished trial {trial.number} with best_loss {best_loss}")
     return best_loss
 
 
@@ -768,7 +798,7 @@ def main():
             callbacks=[TerminatorCallback(terminator)],
         )
     except (KeyboardInterrupt, AbortTrial):
-        print("Trial process aborted")
+        logger.warning("Trial process aborted")
     # fig = optuna.visualization.plot_terminator_improvement(
     #     study,
     #     plot_error=True,
@@ -777,17 +807,17 @@ def main():
     # )
     # fig.write_image(f"kudos_model_improvement_evaluator_{STUDY_VERSION}")
     # Print the best hyperparameters
-    print("Best trial:")
+    logger.info("Best trial:")
     trial = study.best_trial
-    print("Value: ", trial.value)
-    print("Params: ")
+    logger.info(f"Value: {trial.value}")
+    logger.info("Params: ")
     for key, value in trial.params.items():
-        print(f"{key}: {value}")
+        logger.info(f"{key}: {value}")
 
     # Calculate the accuracy of the best model
     best_filename = f"kudos_models/kudos-{STUDY_VERSION}-{trial.number}.ckpt"
     # model = test_one_by_one(best_filename)
-    print(f"Best model file is: {best_filename}")
+    logger.info(f"Best model file is: {best_filename}")
 
 
 if __name__ == "__main__":
@@ -803,6 +833,12 @@ def main():
     VALIDATION_DATA_FILENAME = args.validation_data
     NUMBER_OF_STUDY_TRIALS = args.study_trials
     STUDY_VERSION = args.study_version
+    ENABLE_PROGRESS_BARS = args.progress_bars
+
+    if ENABLE_PROGRESS_BARS:
+        logger.remove()
+        logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True)
+    logger.add("logs/train.log", rotation="10 MB")
 
     # Create SQLite connection string
     DB_CONNECTION_STRING = f"sqlite:///{args.db_path}"
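
Note: the loguru sink registered when --progress-bars is enabled routes log lines through tqdm.write so they print above the bar instead of tearing it. A minimal standalone sketch of that wiring, outside of train.py (the epoch count, sleep, and log message below are illustrative only):

    import time

    from loguru import logger
    from tqdm import tqdm

    logger.remove()  # drop loguru's default stderr handler
    logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True)  # emit via tqdm.write
    logger.add("logs/train.log", rotation="10 MB")  # rotating file log alongside the console

    for epoch in tqdm(range(5), desc="Training Progress"):
        logger.debug(f"Epoch: {epoch}")  # appears above the progress bar
        time.sleep(0.1)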