lower max learning rate to avoid gradient explosion; set verbose and error_score on both searchers
bryanlimy committed Mar 13, 2024
1 parent f70d018 commit 06ba365
Showing 3 changed files with 17 additions and 9 deletions.
6 changes: 3 additions & 3 deletions autoemulate/emulators/neural_networks/mlp.py
@@ -50,18 +50,18 @@ def get_grid_params(search_type: str = "random"):
             nn.Sigmoid,
             nn.GELU,
         ],
-        "optimizer": [torch.optim.AdamW, torch.optim.LBFGS],
+        "optimizer": [torch.optim.SGD, torch.optim.AdamW, torch.optim.LBFGS],
         "optimizer__weight_decay": (1 / 10 ** np.arange(1, 9)).tolist(),
     }
     match search_type:
         case "random":
             param_space |= {
-                "lr": loguniform(1e-06, 1e-2),
+                "lr": loguniform(1e-6, 1e-4),
             }
         case "bayes":
             param_space |= {
                 "optimizer": Categorical(param_space["optimizer"]),
-                "lr": Real(1e-06, 1e-2, prior="log-uniform"),
+                "lr": Real(1e-6, 1e-4, prior="log-uniform"),
             }
         case _:
             raise ValueError(f"Invalid search type: {search_type}")
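Assuming loguniform here is scipy.stats.loguniform (consistent with the random-search branch), this change narrows the sampled learning rates from [1e-6, 1e-2] to [1e-6, 1e-4]. A minimal sketch of what the new space yields:

import numpy as np
from scipy.stats import loguniform

# Sample a few learning-rate candidates from the narrowed space.
lrs = loguniform(1e-6, 1e-4).rvs(size=5, random_state=0)
print(np.sort(lrs))  # every value falls in [1e-6, 1e-4), log-uniformly spread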
6 changes: 3 additions & 3 deletions autoemulate/emulators/neural_networks/rbf.py
@@ -324,18 +324,18 @@ def get_grid_params(search_type: str = "random"):
             rbf_inverse_quadratic,
             rbf_inverse_multiquadric,
         ],
-        "optimizer": [torch.optim.AdamW, torch.optim.SGD],
+        "optimizer": [torch.optim.SGD, torch.optim.AdamW, torch.optim.LBFGS],
         "optimizer__weight_decay": (1 / 10 ** np.arange(1, 9)).tolist(),
     }
     match search_type:
         case "random":
             param_space |= {
-                "lr": loguniform(1e-06, 1e-2),
+                "lr": loguniform(1e-06, 1e-3),
             }
         case "bayes":
             param_space |= {
                 "optimizer": Categorical(param_space["optimizer"]),
-                "lr": Real(1e-06, 1e-2, prior="log-uniform"),
+                "lr": Real(1e-06, 1e-3, prior="log-uniform"),
             }
         case _:
             raise ValueError(f"Invalid search type: {search_type}")
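Both files now search over torch.optim.SGD, AdamW, and LBFGS, while the optimizer__weight_decay grid is left unchanged; for reference, that expression expands to eight log-spaced decay values:

import numpy as np

# The weight-decay grid shared by mlp.py and rbf.py.
print((1 / 10 ** np.arange(1, 9)).tolist())
# [0.1, 0.01, 0.001, 0.0001, 1e-05, 1e-06, 1e-07, 1e-08]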
14 changes: 11 additions & 3 deletions autoemulate/hyperparam_searching.py
@@ -23,6 +23,8 @@ def _optimize_params(
     param_space=None,
     n_jobs=None,
     logger=None,
+    error_score=np.nan,
+    verbose=0,
 ):
     """Performs hyperparameter search for the provided model.
@@ -49,7 +51,11 @@
     n_jobs : int
         Number of jobs to run in parallel.
     logger : logging.Logger
-        Logger instance.
+        Logger instance
+    error_score: 'raise' or numeric
+        Value to assign to the score if an error occurs in estimator fitting.
+    verbose: int
+        Verbosity level for the searcher

     Returns
     -------
@@ -68,7 +74,8 @@
             cv=cv,
             n_jobs=n_jobs,
             refit=True,
-            verbose=0,
+            error_score=error_score,
+            verbose=verbose,
         )
     # Bayes search
     elif search_type == "bayes":
@@ -79,7 +86,8 @@
             cv=cv,
             n_jobs=n_jobs,
             refit=True,
-            verbose=0,
+            error_score=error_score,
+            verbose=verbose,
         )
     elif search_type == "grid":
        raise NotImplementedError("Grid search not available yet.")
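The new error_score and verbose arguments are forwarded to both RandomizedSearchCV and BayesSearchCV. With the default error_score=np.nan, a candidate whose fit raises is scored NaN and skipped, rather than aborting the whole search; passing error_score="raise" restores fail-fast behaviour. A minimal, self-contained sketch of this plumbing, using sklearn's SGDRegressor as a stand-in estimator (not one of autoemulate's emulators):

import numpy as np
from scipy.stats import loguniform
from sklearn.linear_model import SGDRegressor
from sklearn.model_selection import RandomizedSearchCV

# Toy data; eta0 plays the role of the learning rate being searched.
X = np.random.default_rng(0).normal(size=(100, 5))
y = X @ np.ones(5)

searcher = RandomizedSearchCV(
    SGDRegressor(learning_rate="constant"),
    param_distributions={"eta0": loguniform(1e-6, 1e-4)},
    n_iter=10,
    cv=5,
    refit=True,
    error_score=np.nan,  # failed fits score NaN instead of raising
    verbose=1,           # print per-candidate progress
)
searcher.fit(X, y)
print(searcher.best_params_)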
