chore/fix/style: docstrings, typehints, readability
tazlin committed Nov 22, 2024
1 parent e651fe0 commit 1afe4ea
Showing 1 changed file with 76 additions and 48 deletions.
train.py (124 changes: 76 additions & 48 deletions)
@@ -1,29 +1,29 @@
# type: ignore
#
# Train a predictive model from horde payload inputs to predict inference time.
#
# Supports multi-processing, just run this multiple times and the processes will
# automatically work together on the training. We are training with torch and searching
# through network hyper parameters using Optuna.
#
# Requires two input files (both exactly the same format) which can be created by enabling
# the SAVE_KUDOS_TRAINING_DATA constant in the worker.
# - inference-time-data.json
# - inference-time-data-validation.json
#
# The output is a series of model checkpoints, "kudos_models/kudos-X-n.ckpt" Where n is the
# number of the trial and X is the study version. Once the best trial number is identified
# simply select the appropriate file.
#
# The stand-alone class in examples/kudos.py is the code to actually use the model.
#
# Requires also a local mysql database named "optuna" and assumes it can connect
# with user "root" password "root". Change to your needs.
#
# For visualisation with optuna dashboard:
# optuna-dashboard mysql://root:root@localhost/optuna
#
# This is a quick hack to assist with kudos calculation.
"""Train a predictive model from horde payload inputs to predict inference time.
Supports multi-processing, just run this multiple times and the processes will
automatically work together on the training. We are training with torch and searching
through network hyper parameters using Optuna.
Requires two input files (both exactly the same format) which can be created by enabling
the SAVE_KUDOS_TRAINING_DATA constant in the worker.
- inference-time-data.json
- inference-time-data-validation.json
The output is a series of model checkpoints, "kudos_models/kudos-X-n.ckpt" Where n is the
number of the trial and X is the study version. Once the best trial number is identified
simply select the appropriate file.
The stand-alone class in examples/kudos.py is the code to actually use the model.
Requires also a local mysql database named "optuna" and assumes it can connect
with user "root" password "root". Change to your needs.
For visualisation with optuna dashboard:
optuna-dashboard mysql://root:root@localhost/optuna
This is a quick hack to assist with kudos calculation.
"""

import argparse
import json
import math
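The docstring above says multiple processes cooperate automatically on the training. In Optuna that coordination comes from pointing every process at the same study in the shared MySQL storage; a minimal sketch of the idea (the study name and trial count are illustrative, not from this file):

import optuna

# Every process that calls create_study with the same study name and storage URL
# joins the same study and shares its trial history; load_if_exists makes the
# call idempotent, so concurrent processes attach instead of failing.
study = optuna.create_study(
    study_name="kudos-21",  # hypothetical; presumably derived from STUDY_VERSION
    storage="mysql://root:root@localhost/optuna",
    direction="minimize",
    load_if_exists=True,
)
study.optimize(objective, n_trials=100)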
@@ -253,22 +253,32 @@ def signal_handler(sig, frame):
signal.signal(signal.SIGTERM, signal_handler)


# This is an example of how to use the final model, pass in a horde payload, get back a predicted time in seconds
def payload_to_time(model, payload):
def payload_to_time(model: nn.Module, payload: dict) -> float:
"""Return the predicted time in seconds for a given horde payload."""
inputs = KudosDataset.payload_to_tensor(payload).squeeze()
with torch.no_grad():
output = model(inputs)
output: torch.Tensor = model(inputs)
return round(float(output.item()), 2)


# This is how to load the model required above
def load_model(model_filename):
def load_model(model_filename) -> nn.Module:
with open(model_filename, "rb") as infile:
return pickle.load(infile)
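
A minimal usage sketch of the two helpers above (the checkpoint name, trial number, and sample record are illustrative, not taken from the repository):

model = load_model("kudos_models/kudos-21-7.ckpt")
with open("inference-time-data-validation.json") as f:
    sample_payload = json.load(f)[0]
print(payload_to_time(model, sample_payload))  # e.g. 3.42 (seconds)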


class PercentageLoss(torch.nn.Module):
def forward(self, predicted, actual):
"""Torch module to calculate the percentage loss between two tensors."""

def forward(self, predicted: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
"""Calculate the percentage loss between the predicted and actual time.
Args:
predicted (torch.Tensor): The predicted time in seconds
actual (torch.Tensor): The actual time in seconds
Returns:
torch.Tensor: The percentage loss
"""
diff = torch.abs(actual - predicted)
max_val = torch.max(actual, predicted)
# We make it an order of magnitude higher, so that it appears clearer on the graphs
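        # (The actual return statement falls in the collapsed part of this diff; the
        # line below is a plausible completion consistent with the comment above,
        # an assumption rather than the verbatim source.)
        return torch.mean(diff / max_val) * 10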
@@ -303,7 +313,7 @@ def flatten_dict(d: dict, parent_key: str = "") -> dict[str, Any]:
"post_processing_order",
}

items = []
items: list[tuple[str, Any]] = []
for k, v in d.items():
new_key = f"{parent_key}.{k}" if parent_key else k
if isinstance(v, dict):
@@ -349,7 +359,7 @@ def are_values_similar(val1: Any, val2: Any) -> bool:
if not values:
continue

value_groups = defaultdict(int)
value_groups: dict[str, float] = defaultdict(int)
processed_values = set()

for i, val1 in enumerate(values):
@@ -500,7 +510,7 @@ def __init__(self, filename):
logger.debug(f"Loaded {len(self.data)} samples")

@classmethod
def payload_to_tensor(cls, payload):
def payload_to_tensor(cls, payload: dict) -> torch.Tensor:
payload = payload["sdk_api_job_info"]
p = payload["payload"]
data = []
@@ -590,15 +600,20 @@ def __getitem__(self, idx):
return self.mixed_data[idx], self.labels[idx]


def create_sequential_model(trial, layer_sizes, input_size, output_size=1):
def create_sequential_model(
trial: optuna.Trial,
layer_sizes: list[int],
input_size: int,
output_size: int = 1,
) -> nn.Sequential:
logger.debug(
f"Creating model with input size {input_size}, output size {output_size}, and layer sizes {layer_sizes}",
)
# Define the layer sizes
layer_sizes = [input_size] + layer_sizes + [output_size]

# Create the layers and activation functions
layers = []
layers: list[nn.Module] = []
for i in range(len(layer_sizes) - 1):
layers.append(nn.Linear(layer_sizes[i], layer_sizes[i + 1]))
if i < len(layer_sizes) - 2:
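            # (Collapsed in this view: the activation appended for non-final layers.
            # A fixed ReLU is assumed here; the real code may suggest the activation
            # through the Optuna trial.)
            layers.append(nn.ReLU())

    return nn.Sequential(*layers)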
@@ -639,7 +654,7 @@ def objective(trial: optuna.Trial) -> float:
lr = trial.suggest_float("learning_rate", MIN_LEARNING_RATE, MAX_LEARNING_RATE, log=True)
weight_decay = trial.suggest_float("weight_decay", MIN_WEIGHT_DECAY, MAX_WEIGHT_DECAY, log=True)

optimizer = None
optimizer: optim.Optimizer | None = None

if optimizer_name == "Adam":
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
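    # (Collapsed in this view: the remaining optimizer branches. They presumably
    # mirror the Adam case for the other suggested names; "AdamW" is an assumption.)
    elif optimizer_name == "AdamW":
        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)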
@@ -662,15 +677,22 @@ def objective(trial: optuna.Trial) -> float:
validate_dataset = KudosDataset(VALIDATION_DATA_FILENAME)
validate_loader = DataLoader(validate_dataset, batch_size=64, shuffle=True)

if validate_loader is None:
raise Exception("No validation data")
# Loss function
# criterion = nn.HuberLoss()
criterion = PercentageLoss()

total_loss = None
best_epoch = best_loss = best_state_dict = None
total_loss = 0.0
best_epoch = 0
best_loss = float("inf")
best_state_dict = None

patience = trial.suggest_int("patience", MIN_PATIENCE, MAX_PATIENCE) if USE_PATIENCE else 0
epochs_since_best = 0

pbar: range | tqdm

if ENABLE_PROGRESS_BARS:
pbar = tqdm(range(NUM_EPOCHS), desc="Training Progress")
else:
@@ -681,26 +703,28 @@ def objective(trial: optuna.Trial) -> float:
for epoch in pbar:
# Train the model
model.train()
data: torch.Tensor
labels: torch.Tensor
for data, labels in train_loader:
data = data.to(device)
labels = labels.to(device)
optimizer.zero_grad()
labels = labels.unsqueeze(1)
outputs = model(data)
loss = criterion(outputs, labels)
loss: torch.Tensor = criterion(outputs, labels)
loss.backward()
optimizer.step()

model.eval()
total_loss = 0
total_loss = 0.0
with torch.no_grad():
for data, labels in validate_loader:
data = data.to(device)
labels = labels.to(device)
outputs = model(data)
labels = labels.unsqueeze(1)
loss = criterion(outputs, labels)
total_loss += loss
total_loss += float(loss)

total_loss /= len(validate_loader)
total_loss = round(float(total_loss), 4)
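        # (Collapsed in this view: the best-loss bookkeeping. A typical patience-based
        # update over the variables visible above would look like this sketch;
        # copy.deepcopy assumes "import copy" at the top of the file.)
        if total_loss < best_loss:
            best_loss = total_loss
            best_epoch = epoch
            best_state_dict = copy.deepcopy(model.state_dict())
            epochs_since_best = 0
        else:
            epochs_since_best += 1
            if USE_PATIENCE and epochs_since_best > patience:
                break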
@@ -719,22 +743,26 @@ def objective(trial: optuna.Trial) -> float:
f"input_size={input_size}, layers={layers}, output_size={output_size} "
f"batch_size={batch}, optimizer={optimizer_name}, lr={lr}, weight_decay={weight_decay}"
)
if ENABLE_PROGRESS_BARS:
pbar.set_description(info_str)
logger.debug(info_str)

if ENABLE_PROGRESS_BARS:
if ENABLE_PROGRESS_BARS and isinstance(pbar, tqdm):
pbar.set_description(info_str)
pbar.set_postfix(
loss=total_loss,
best_loss=best_loss,
epochs_since_best=epochs_since_best,
)

logger.debug(info_str)
logger.debug(
f"Epoch: {epoch}, Loss: {total_loss}, Best Loss: {best_loss}, Epochs since best: {epochs_since_best}",
)

# reload the best performing model we found
model.load_state_dict(best_state_dict)
if best_state_dict is not None:
logger.debug(f"Reloading best model from epoch {best_epoch}")
model.load_state_dict(best_state_dict)
else:
logger.error("No best model found")

# Pickle it as we'll forget the model architecture
filename = f"kudos_models/kudos-{STUDY_VERSION}-{trial.number}.ckpt"
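    # (Collapsed in this view: the save step implied by the comment above; this is
    # a sketch of the usual pickling pattern.)
    with open(filename, "wb") as outfile:
        pickle.dump(model, outfile)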
@@ -746,7 +774,7 @@ def objective(trial: optuna.Trial) -> float:
return best_loss


def main():
def main() -> None:

if args.test_model:
low_predictions = test_one_by_one(args.test_model)
