(MA) implemented a hyperparameter search matrix for 27 common options; meant mostly for tuning to smaller datasets
amkrajewski committed Mar 28, 2024
1 parent 2faea84 commit 47e3fbf
Showing 1 changed file with 159 additions and 2 deletions.
161 changes: 159 additions & 2 deletions pysipfenn/core/modelAdjusters.py
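For orientation, here is a minimal usage sketch of the new method added in this commit. The construction of the `adjuster` instance is elided (the `LocalAdjuster` constructor is not shown in this diff), but the call signature and defaults match the code below:

```python
from pysipfenn.core.modelAdjusters import LocalAdjuster

# Hypothetical: an already-constructed LocalAdjuster (or subclass) instance,
# built elsewhere with a Calculator, a model, and reference data.
adjuster: LocalAdjuster = ...

# Runs the default 3 x 3 x 3 = 27-combination grid search and keeps the model
# with the lowest validation loss (validation=0.2 by default).
bestModel, bestHyperparameters = adjuster.matrixHyperParameterSearch(
    validation=0.2,
    epochs=100,
    batchSize=32,
    lossFunction="MAE",
    learningRates=(1e-6, 1e-5, 1e-4),
    optimizers=("Adam", "AdamW", "Adamax"),
    weightDecays=(1e-5, 1e-4, 1e-3),
    verbose=True,
    plot=True,
)
print(bestHyperparameters)  # e.g. {"learningRate": 1e-5, "optimizer": "AdamW", ...}
```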
@@ -1,12 +1,13 @@
import os
from typing import Union, Literal, Tuple, List
from typing import Union, Literal, Tuple, List, Dict
from copy import deepcopy
import gc

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
import plotly.express as px
import plotly.graph_objects as go
from pysipfenn.core.pysipfenn import Calculator

class LocalAdjuster:
@@ -257,7 +258,6 @@ def adjust(
        if verbose:
            print(f'Train: {transferLosses[-1]:.4f} | Epoch: 0/{epochs}')

        for epoch in range(epochs):
            model.train()
            for data, target in dataloaderTrain:
@@ -305,7 +305,160 @@ def adjust(

        return self.adjustedModel, transferLosses, validationLosses

    def matrixHyperParameterSearch(
            self,
            validation: float = 0.2,
            epochs: int = 100,
            batchSize: int = 32,
            lossFunction: Literal["MSE", "MAE"] = "MAE",
            learningRates: Tuple[float, ...] = (1e-6, 1e-5, 1e-4),
            optimizers: Tuple[Literal["Adam", "AdamW", "Adamax", "RMSprop"], ...] = ("Adam", "AdamW", "Adamax"),
            weightDecays: Tuple[float, ...] = (1e-5, 1e-4, 1e-3),
            verbose: bool = True,
            plot: bool = True
    ) -> Tuple[torch.nn.Module, Dict[str, Union[float, str]]]:
"""
Performs a grid search over the hyperparameters provided to find the best combination. By default, it will
plot the training history with plotly in your browser, and (b) print the best hyperparameters found. If the
ClearML platform was set to be used for logging (at the class initialization), the results will be uploaded
there as well. If the default values are used, it will test 27 combinations of learning rates, optimizers, and
weight decays. The method will then adjust the model to the best hyperparameters found, corresponding to the
lowest validation loss if validation is used, or the lowest training loss if validation is not used
(``validation=0``). Note that the validation is used by default.
Args:
validation: Same as in the ``adjust`` method. Default is ``0.2``.
epochs: Same as in the ``adjust`` method. Default is ``100``.
batchSize: Same as in the ``adjust`` method. Default is ``32``.
lossFunction: Same as in the ``adjust`` method. Default is ``MAE``, i.e. Mean Absolute Error or L1 loss.
learningRates: Tuple of floats with the learning rates to be tested. Default is ``(1e-6, 1e-5, 1e-4)``. See
the ``adjust`` method for more information.
optimizers: Tuple of strings with the optimizers to be tested. Default is ``("Adam", "AdamW", "Adamax")``. See
the ``adjust`` method for more information.
weightDecays: Tuple of floats with the weight decays to be tested. Default is ``(1e-5, 1e-4, 1e-3)``. See
the ``adjust`` method for more information.
verbose: Same as in the ``adjust`` method. Default is ``True``.
plot: Whether to plot the training history after all the combinations are tested. Default is ``True``.
"""
        if verbose:
            print("Starting the hyperparameter search...")

        bestModel: Union[torch.nn.Module, None] = None
        bestTrainingLoss: float = np.inf
        bestValidationLoss: float = np.inf
        bestHyperparameters: Dict[str, Union[float, str, None]] = {
            "learningRate": None,
            "optimizer": None,
            "weightDecay": None,
            "epochs": None
        }

        trainLossHistory: List[List[float]] = []
        validationLossHistory: List[List[float]] = []
        labels: List[str] = []
        for learningRate in learningRates:
            for optimizer in optimizers:
                for weightDecay in weightDecays:
                    labels.append(f"LR: {learningRate} | OPT: {optimizer} | WD: {weightDecay}")
                    model, trainingLoss, validationLoss = self.adjust(
                        validation=validation,
                        learningRate=learningRate,
                        epochs=epochs,
                        batchSize=batchSize,
                        optimizer=optimizer,
                        weightDecay=weightDecay,
                        lossFunction=lossFunction,
                        verbose=verbose
                    )
                    trainLossHistory.append(trainingLoss)
                    validationLossHistory.append(validationLoss)
                    if validation > 0:
                        # Find the epoch with the lowest validation loss for this combination
                        localBestValidationLoss, bestEpoch = min((val, idx) for idx, val in enumerate(validationLoss))
                        if localBestValidationLoss < bestValidationLoss:
                            print(f"New best model found with LR: {learningRate}, OPT: {optimizer}, WD: {weightDecay}, "
                                  f"Epoch: {bestEpoch + 1}/{epochs} | Train: {trainingLoss[bestEpoch]:.4f} | "
                                  f"Validation: {localBestValidationLoss:.4f}")
                            # Free the previously held best model before replacing it
                            del bestModel
                            gc.collect()
                            bestModel = model
                            bestTrainingLoss = trainingLoss[bestEpoch]
                            bestValidationLoss = localBestValidationLoss
                            bestHyperparameters["learningRate"] = learningRate
                            bestHyperparameters["optimizer"] = optimizer
                            bestHyperparameters["weightDecay"] = weightDecay
                            bestHyperparameters["epochs"] = bestEpoch + 1
                        else:
                            print(f"Model with LR: {learningRate}, OPT: {optimizer}, WD: {weightDecay} did not improve.")
                    else:
                        # No validation split, so select on training loss instead
                        localBestTrainingLoss, bestEpoch = min((val, idx) for idx, val in enumerate(trainingLoss))
                        if localBestTrainingLoss < bestTrainingLoss:
                            print(f"New best model found with LR: {learningRate}, OPT: {optimizer}, WD: {weightDecay}, "
                                  f"Epoch: {bestEpoch + 1}/{epochs} | Train: {localBestTrainingLoss:.4f}")
                            del bestModel
                            gc.collect()
                            bestModel = model
                            bestTrainingLoss = localBestTrainingLoss
                            bestHyperparameters["learningRate"] = learningRate
                            bestHyperparameters["optimizer"] = optimizer
                            bestHyperparameters["weightDecay"] = weightDecay
                            bestHyperparameters["epochs"] = bestEpoch + 1
                        else:
                            print(f"Model with LR: {learningRate}, OPT: {optimizer}, WD: {weightDecay} did not improve.")

        if verbose:
            print(f"\n\nBest model found with LR: {bestHyperparameters['learningRate']}, OPT: {bestHyperparameters['optimizer']}, "
                  f"WD: {bestHyperparameters['weightDecay']}, Epoch: {bestHyperparameters['epochs']}")
            if validation > 0:
                print(f"Train: {bestTrainingLoss:.4f} | Validation: {bestValidationLoss:.4f}")
            else:
                print(f"Train: {bestTrainingLoss:.4f}")
        assert bestModel is not None, "The best model was not found. Something went wrong during the hyperparameter search."
        self.adjustedModel = bestModel
        del bestModel
        gc.collect()

        if plot:
            fig1 = go.Figure()
            for idx, label in enumerate(labels):
                fig1.add_trace(
                    go.Scatter(
                        # Loss histories include the initial (epoch 0) evaluation, hence epochs + 1 points
                        x=np.arange(epochs + 1),
                        y=trainLossHistory[idx],
                        mode='lines+markers',
                        name=label
                    )
                )
            fig1.update_layout(
                title="Training Loss History",
                xaxis_title="Epoch",
                yaxis_title="Loss",
                legend_title="Hyperparameters",
                showlegend=True,
                template="plotly_white"
            )
            fig1.show()
            if validation > 0:
                fig2 = go.Figure()
                for idx, label in enumerate(labels):
                    fig2.add_trace(
                        go.Scatter(
                            x=np.arange(epochs + 1),
                            y=validationLossHistory[idx],
                            mode='lines+markers',
                            name=label
                        )
                    )
                fig2.update_layout(
                    title="Validation Loss History",
                    xaxis_title="Epoch",
                    yaxis_title="Loss",
                    legend_title="Hyperparameters",
                    showlegend=True,
                    template="plotly_white"
                )
                fig2.show()

        return self.adjustedModel, bestHyperparameters



@@ -317,3 +470,7 @@ class OPTIMADEAdjuster(LocalAdjuster):
    settings used by that database or focusing its attention on specific chemistry like, for instance, all compounds of
    Sn and all perovskites. It accepts an OPTIMADE query as an input and then operates based on the ``LocalAdjuster`` class.
    """



