fix: extra logging and optional progress bars to train.py
tazlin committed Nov 22, 2024
1 parent 2abaf7d commit e651fe0
Showing 1 changed file with 51 additions and 15 deletions.
66 changes: 51 additions & 15 deletions train.py
@@ -39,6 +39,7 @@
import optunahub
import torch
import torch.nn as nn
from loguru import logger
from optuna.distributions import CategoricalDistribution, FloatDistribution, IntDistribution
from optuna.terminator import EMMREvaluator, MedianErrorEvaluator, Terminator, TerminatorCallback
from torch import optim
@@ -224,6 +225,14 @@ def parse_args():
help="Path to validation data file",
)

parser.add_argument(
"-p",
"--progress-bars",
action="store_true",
default=False,
help="Enable progress bars for epoch and trial progress",
)

# Study parameters
parser.add_argument("--study-trials", type=int, default=2000, help="Number of trials to run")

@@ -460,6 +469,7 @@ def test_one_by_one(model_filename):

class KudosDataset(Dataset):
def __init__(self, filename):
logger.debug(f"Loading dataset from {filename}")
self.data = []
self.labels = []

@@ -487,6 +497,7 @@ def __init__(self, filename):
self.labels.append(payload["time_to_generate"])
self.labels = torch.tensor(self.labels).float()
self.mixed_data = torch.stack(self.data)
logger.debug(f"Loaded {len(self.data)} samples")

@classmethod
def payload_to_tensor(cls, payload):
@@ -580,6 +591,9 @@ def __getitem__(self, idx):


def create_sequential_model(trial, layer_sizes, input_size, output_size=1):
logger.debug(
f"Creating model with input size {input_size}, output size {output_size}, and layer sizes {layer_sizes}",
)
# Define the layer sizes
layer_sizes = [input_size] + layer_sizes + [output_size]

@@ -601,7 +615,8 @@ def create_sequential_model(trial, layer_sizes, input_size, output_size=1):
return nn.Sequential(*layers)


def objective(trial):
def objective(trial: optuna.Trial) -> float:
logger.debug(f"Starting trial {trial.number}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

trial.set_user_attr("name", "predict_kudos")
@@ -656,7 +671,13 @@ def objective(trial):
patience = trial.suggest_int("patience", MIN_PATIENCE, MAX_PATIENCE) if USE_PATIENCE else 0
epochs_since_best = 0

pbar = tqdm(range(NUM_EPOCHS), desc="Training Progress")
if ENABLE_PROGRESS_BARS:
pbar = tqdm(range(NUM_EPOCHS), desc="Training Progress")
else:
pbar = range(NUM_EPOCHS)

logger.debug(f"Starting training for {NUM_EPOCHS} epochs")

for epoch in pbar:
# Train the model
model.train()
@@ -694,15 +715,22 @@ def objective(trial):
# Stop early, no improvement in awhile
break

pbar.set_description(
info_str = (
f"input_size={input_size}, layers={layers}, output_size={output_size} "
f"batch_size={batch}, optimizer={optimizer_name}, lr={lr}, weight_decay={weight_decay}",
f"batch_size={batch}, optimizer={optimizer_name}, lr={lr}, weight_decay={weight_decay}"
)

pbar.set_postfix(
loss=total_loss,
best_loss=best_loss,
epochs_since_best=epochs_since_best,
if ENABLE_PROGRESS_BARS:
pbar.set_description(info_str)
logger.debug(info_str)
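# Loss metrics below are always mirrored to the debug log; the tqdm postfix is only updated when bars are enabled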

if ENABLE_PROGRESS_BARS:
pbar.set_postfix(
loss=total_loss,
best_loss=best_loss,
epochs_since_best=epochs_since_best,
)
logger.debug(
f"Epoch: {epoch}, Loss: {total_loss}, Best Loss: {best_loss}, Epochs since best: {epochs_since_best}",
)

# reload the best performing model we found
@@ -712,7 +740,9 @@
filename = f"kudos_models/kudos-{STUDY_VERSION}-{trial.number}.ckpt"
with open(filename, "wb") as outfile:
pickle.dump(model.to("cpu"), outfile)
logger.debug(f"Saved model to {filename}")

logger.debug(f"Finished trial {trial.number} with best_loss {best_loss}")
return best_loss


@@ -768,7 +798,7 @@ def main():
callbacks=[TerminatorCallback(terminator)],
)
except (KeyboardInterrupt, AbortTrial):
print("Trial process aborted")
logger.warning("Trial process aborted")
# fig = optuna.visualization.plot_terminator_improvement(
# study,
# plot_error=True,
@@ -777,17 +807,17 @@ def main():
# )
# fig.write_image(f"kudos_model_improvement_evaluator_{STUDY_VERSION}")
# Print the best hyperparameters
print("Best trial:")
logger.info("Best trial:")
trial = study.best_trial
print("Value: ", trial.value)
print("Params: ")
logger.info("Value: ", trial.value)
logger.info("Params: ")
for key, value in trial.params.items():
print(f"{key}: {value}")
logger.info(f"{key}: {value}")

# Calculate the accuracy of the best model
best_filename = f"kudos_models/kudos-{STUDY_VERSION}-{trial.number}.ckpt"
# model = test_one_by_one(best_filename)
print(f"Best model file is: {best_filename}")
logger.info(f"Best model file is: {best_filename}")


if __name__ == "__main__":
@@ -803,6 +833,12 @@ def main():
VALIDATION_DATA_FILENAME = args.validation_data
NUMBER_OF_STUDY_TRIALS = args.study_trials
STUDY_VERSION = args.study_version
ENABLE_PROGRESS_BARS = args.progress_bars

if ENABLE_PROGRESS_BARS:
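# Replace loguru's default stderr sink with one that writes via tqdm.write, so log lines don't corrupt active progress bars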
logger.remove()
logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True)
logger.add("logs/train.log", rotation="10 MB")

# Create SQLite connection string
DB_CONNECTION_STRING = f"sqlite:///{args.db_path}"
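For context, a minimal self-contained sketch (not part of this commit) of the loguru-through-tqdm pattern the new logging setup relies on: routing log records through tqdm.write keeps an active progress bar intact instead of leaving broken bar fragments on the console.

from loguru import logger
from tqdm import tqdm
import time

logger.remove()  # drop the default stderr sink
logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True)  # emit via tqdm so bars are redrawn cleanly

for step in tqdm(range(5), desc="demo"):
    logger.debug(f"step {step}")  # printed above the bar rather than through it
    time.sleep(0.1)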
