Merge pull request #51 from alan-turing-institute/fix-hyperparam-search
Fix hyperparam search
mastoffel authored Nov 15, 2023
2 parents 5f5d099 + 2be73b1 commit 3c8afda
Showing 3 changed files with 12 additions and 6 deletions.
autoemulate/compare.py: 11 changes (8 additions, 3 deletions)
@@ -142,7 +142,7 @@ def compare(self):

        for i, model in enumerate(self.models):
            updated_model = (
-               self._get_best_hyperparams_hyperparams(i, model)
+               self._get_best_hyperparams(i, model)
                if self.hyperparameter_search
                else model
            )
@@ -243,7 +243,12 @@ def _get_best_hyperparams(self, model_index, model):
"""
# Perform hyperparameter search and update model
hyperparam_searcher = HyperparamSearch(
self.X, self.y, self.cv, self.n_jobs, self.logger
X=self.X,
y=self.y,
cv=self.cv,
param_grid=None,
n_jobs=self.n_jobs,
logger=self.logger,
)
updated_model = hyperparam_searcher.search(model)
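Given the HyperparamSearch.__init__ signature shown further down in this commit (X, y, cv, n_jobs, param_grid=None, logger=None), the old positional call bound self.logger to param_grid and left logger as None, which appears to be the bug this change addresses. A minimal standalone sketch of that binding mistake (the toy data values are made up; the class below only mirrors the signature and is not the real implementation):

    import logging

    # Standalone stand-in that only mirrors the __init__ signature from hyperparam_search.py.
    class HyperparamSearch:
        def __init__(self, X, y, cv, n_jobs, param_grid=None, logger=None):
            self.param_grid = param_grid
            self.logger = logger

    logger = logging.getLogger(__name__)

    # Old call in compare.py (positional): the fifth argument lands in param_grid.
    old = HyperparamSearch([[0.0]], [0.0], 5, 1, logger)
    print(old.param_grid)  # the logger object, silently treated as a parameter grid
    print(old.logger)      # None

    # New call (keyword arguments): every value is bound to the intended parameter.
    new = HyperparamSearch(X=[[0.0]], y=[0.0], cv=5, param_grid=None, n_jobs=1, logger=logger)
    print(new.param_grid)  # None
    print(new.logger)      # the logger object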

@@ -299,7 +304,7 @@ def _get_best_model(self, metric="r2"):

        return best_model

-   def print_scores(self, model=None):
+   def print_results(self, model=None):
        # check if model is in self.models
        if model is not None:
            model_names = [type(model).__name__ for model in self.models]
autoemulate/datasets.py: 3 changes (2 additions, 1 deletion)
@@ -20,7 +20,8 @@ def fetch_dataset(dataset_name, train_test=False, test_size=0.2, random_state=42
"four_chamber" is the 4-chamber model without LH sampling.
"circ_adapt" is the circulatory adaptation model also from LH sampling.
train_test : bool, optional
If True, returns the dataset split into training and testing sets.
If True, returns the dataset split into training and testing sets,
X_train, X_test, y_train, y_test.
If False, returns the entire dataset. Default is False.
test_size : float, optional
The proportion of the dataset to include in the test split.
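A brief usage sketch of the behaviour documented above (the import path follows the file location autoemulate/datasets.py; the return value for train_test=False is not fully specified in this diff, so that branch is only hinted at in a comment):

    from autoemulate.datasets import fetch_dataset

    # train_test=True returns the split in the documented order:
    # X_train, X_test, y_train, y_test.
    X_train, X_test, y_train, y_test = fetch_dataset(
        "circ_adapt", train_test=True, test_size=0.2, random_state=42
    )

    # train_test=False (the default) returns the entire, unsplit dataset;
    # its exact shape is not shown in this diff.
    data = fetch_dataset("four_chamber")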
autoemulate/hyperparam_search.py: 4 changes (2 additions, 2 deletions)
@@ -33,7 +33,7 @@ def __init__(self, X, y, cv, n_jobs, param_grid=None, logger=None):
        self.cv = cv
        self.n_jobs = n_jobs
        self.param_grid = param_grid
-       self.logger = logger or logging.getLogger(__name__)
+       self.logger = logger
        self.best_params = {}

    def search(self, model):
@@ -64,7 +64,7 @@ def prepare_param_grid(model, param_grid=None):
"""Prepares the parameter grid with prefixed parameters."""
if param_grid is None:
param_grid = model.named_steps["model"].get_grid_params()
print(f"param_grid: {param_grid}")
# print(f"param_grid: {param_grid}")
return {f"model__{key}": value for key, value in param_grid.items()}

@staticmethod
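The model__ prefix produced by prepare_param_grid follows scikit-learn's step__parameter addressing for a Pipeline step named "model", so the resulting grid can be handed to a search over the whole pipeline. A minimal sketch of just that key transformation (the example grid values are invented for illustration; in the package the grid would come from model.named_steps["model"].get_grid_params(), as in the diff above):

    # Invented example grid for illustration only.
    param_grid = {"n_estimators": [100, 200], "max_depth": [None, 10]}

    # Same comprehension as the return statement above: prefix each key with the
    # pipeline step name so scikit-learn routes it to that step.
    prefixed = {f"model__{key}": value for key, value in param_grid.items()}
    print(prefixed)
    # {'model__n_estimators': [100, 200], 'model__max_depth': [None, 10]}

A GridSearchCV (or similar search object) built over the full pipeline would then forward n_estimators and max_depth to the estimator registered under the "model" step.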
