Skip to content

Commit

Permalink
Add optuna_engine_kwargs parameter to Autotuner_BaseEstimator and Aut…
Browse files Browse the repository at this point in the history
…o_Classifier classes
  • Loading branch information
RichieHakim committed Mar 20, 2024
1 parent 70ced70 commit 11c18e7
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions bnpm/automatic_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ class Autotuner_BaseEstimator:
The name of the study in the storage. Should only be ``None`` if
``optuna_storage_url`` is ``None``, else it should be a string to a
valid study name.
optuna_engine_kwargs (Optional[Dict[str, Any]]):
Additional keyword arguments to pass to the SQLAlchemy engine.
wandb_project (Optional[str]):
The name of the Weights and Biases project to log to. If ``None``,
then no logging is done. Assumes that a wandb.init() has already
Expand Down Expand Up @@ -126,6 +128,7 @@ def __init__(
verbose=True,
optuna_storage_url: Optional[str] = None,
optuna_storage_name: Optional[str] = None,
optuna_engine_kwargs: Optional[Dict[str, Any]] = None,
wandb_project: Optional[str] = None,
):
"""
Expand Down Expand Up @@ -161,6 +164,7 @@ def __init__(
self.verbose = verbose
self.optuna_storage_url = optuna_storage_url
self.optuna_storage_name = optuna_storage_name
self.optuna_engine_kwargs = optuna_engine_kwargs

self.wandb_project = wandb_project

Expand Down Expand Up @@ -241,7 +245,6 @@ def _objective(self, trial: optuna.trial.Trial) -> float:
for param, (param_linked, fn) in self.params_linked.items():
kwargs_model[param] = fn(kwargs_model[param_linked])


# Train the model
loss_train_all, loss_test_all, loss_all = [], [], []
for ii in range(self.n_repeats):
Expand Down Expand Up @@ -339,12 +342,16 @@ def fit(self) -> Union[sklearn.base.BaseEstimator, Optional[Dict[str, Any]]]:
optuna.logging.set_verbosity(optuna.logging.DEBUG)

# Initialize an Optuna study
storage = optuna.storages.RDBStorage(
url=self.optuna_storage_url,
engine_kwargs=self.optuna_engine_kwargs,
)
self.study = optuna.create_study(
direction="minimize",
pruner=optuna.pruners.MedianPruner(n_startup_trials=self.n_startup),
sampler=optuna.samplers.TPESampler(n_startup_trials=self.n_startup),
study_name='Autotuner' if self.optuna_storage_name is None else self.optuna_storage_name,
storage=self.optuna_storage_url,
storage=storage,
load_if_exists=True,
)

Expand Down Expand Up @@ -776,6 +783,8 @@ class Auto_Classifier(Autotuner_BaseEstimator):
The name of the study in the storage. Should only be ``None`` if
``optuna_storage_url`` is ``None``, else it should be a string to a
valid study name.
optuna_engine_kwargs (Optional[Dict[str, Any]]):
Additional keyword arguments to pass to the SQLAlchemy engine.
wandb_project (Optional[str]):
The name of the Weights and Biases project to log to. If ``None``,
then no logging is done. Assumes that a wandb.init() has already
Expand Down Expand Up @@ -852,6 +861,7 @@ def __init__(
verbose: bool = True,
optuna_storage_url: Optional[str] = None,
optuna_storage_name: Optional[str] = None,
optuna_engine_kwargs: Optional[Dict[str, Any]] = None,
wandb_project: Optional[str] = None,
) -> None:
"""
Expand Down Expand Up @@ -906,6 +916,7 @@ def __init__(
verbose=verbose,
optuna_storage_url=optuna_storage_url,
optuna_storage_name=optuna_storage_name,
optuna_engine_kwargs=optuna_engine_kwargs,
wandb_project=wandb_project,
)

Expand Down Expand Up @@ -1057,6 +1068,8 @@ class Auto_Regression(Autotuner_BaseEstimator):
The name of the study in the storage. Should only be ``None`` if
``optuna_storage_url`` is ``None``, else it should be a string to a
valid study name.
optuna_engine_kwargs (Optional[Dict[str, Any]]):
Additional keyword arguments to pass to the SQLAlchemy engine.
wandb_project (Optional[str]):
The name of the Weights and Biases project to log to. If ``None``,
then no logging is done. Assumes that a wandb.init() has already
Expand Down Expand Up @@ -1131,6 +1144,7 @@ def __init__(
verbose: bool = True,
optuna_storage_url: Optional[str] = None,
optuna_storage_name: Optional[str] = None,
optuna_engine_kwargs: Optional[Dict[str, Any]] = None,
wandb_project: Optional[str] = None,
) -> None:
"""
Expand Down Expand Up @@ -1178,6 +1192,7 @@ def __init__(
verbose=verbose,
optuna_storage_url=optuna_storage_url,
optuna_storage_name=optuna_storage_name,
optuna_engine_kwargs=optuna_engine_kwargs,
wandb_project=wandb_project,
)

Expand Down

0 comments on commit 11c18e7

Please sign in to comment.