Merge pull request #232 from alan-turing-institute/fix-get_model
fix get_model
mastoffel authored Aug 13, 2024
Parents: 9345877 + a15442c · Commit: cd08f63
Showing 2 changed files with 36 additions and 19 deletions.
autoemulate/compare.py (35 changes: 19 additions & 16 deletions)
```diff
@@ -254,51 +254,54 @@ def compare(self):
         )

         # get best model
-        best_model_name, self.best_model = self.get_model(
-            rank=1, metric="r2", name=True
-        )
+        self.best_model = self.get_model(rank=1, metric="r2")

         return self.best_model

-    def get_model(self, rank=1, metric="r2", name=False):
-        """Get a fitted model based on it's rank in the comparison.
+    def get_model(self, rank=1, metric="r2", name=None):
+        """Get a fitted model based on its rank in the comparison or its name.

         Parameters
         ----------
         rank : int
             Rank of the model to return. Defaults to 1, which is the best model, 2 is the second best, etc.
         metric : str
             Metric to use for determining the best model.
-        name : bool
-            If True, returns tuple of model name and model. If False, returns only the model.
+        name : str
+            Name of the model to return.

         Returns
         -------
         model : object
             Model fitted on full data.
         """

-        if not hasattr(self, "scores_df"):
+        # get model by name
+        if name is not None:
+            if not isinstance(name, str):
+                raise ValueError("Name must be a string")
+            for model in self.models:
+                if get_model_name(model) == name or get_short_model_name(model) == name:
+                    return model
+            raise ValueError(f"Model {name} not found")
+
+        # check that comparison has been run
+        if not hasattr(self, "scores_df") and name is None:
             raise RuntimeError("Must run compare() before get_model()")

-        # get average scores across folds
-        means = get_mean_scores(self.scores_df, metric)
+        # get model by rank
+        means = get_mean_scores(self.scores_df, metric)

         if (rank > len(means)) or (rank < 1):
             raise RuntimeError(f"Rank must be >= 1 and <= {len(means)}")
         chosen_model_name = means.iloc[rank - 1]["model"]

-        # get best model:
         for model in self.models:
             if get_model_name(model) == chosen_model_name:
                 chosen_model = model
                 break

-        # check whether the model is fitted
-        check_is_fitted(chosen_model)
-
-        if name:
-            return chosen_model_name, chosen_model
+        # check_is_fitted(chosen_model)
         return chosen_model

     def refit_model(self, model):
```
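In short, `compare()` now stores and returns only the fitted model, and `get_model` accepts either a rank or a full or short model name. A minimal usage sketch; the `AutoEmulate` class, the `setup(X, y)` step, and the toy data are assumptions for illustration and do not appear in this diff:

```python
# Usage sketch of the changed get_model API. AutoEmulate, setup(), and the
# sample data below are assumptions, not taken from this diff.
import numpy as np
from autoemulate.compare import AutoEmulate

X = np.random.uniform(0, 1, size=(50, 2))  # toy simulator inputs
y = np.sum(X, axis=1)                      # toy simulator output

ae = AutoEmulate()
ae.setup(X, y)              # assumed setup step before comparison
best = ae.compare()         # now returns only the model, not (name, model)

second = ae.get_model(rank=2, metric="r2")  # by rank (1 = best)
rf = ae.get_model(name="RandomForest")      # by full name
rf_short = ae.get_model(name="rf")          # or by short name
```

Returning a bare model keeps the common case simple; callers who still need the name can recover it with `get_model_name(model)` from `autoemulate.utils`.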
tests/test_compare.py (20 changes: 17 additions & 3 deletions)
```diff
@@ -12,6 +12,7 @@
 from autoemulate.experimental_design import ExperimentalDesign
 from autoemulate.experimental_design import LatinHypercube
 from autoemulate.metrics import METRIC_REGISTRY
+from autoemulate.utils import get_model_name


 @pytest.fixture()
@@ -114,14 +115,27 @@ def test__get_metrics(ae):


 # -----------------------test get_model-------------------#
-def test_get_model(ae_run):
-    # Test getting the best model
+def test_get_model_by_name(ae_run):
+    model = ae_run.get_model(name="RandomForest")
+    assert get_model_name(model) == "RandomForest"
+
+
+def test_get_model_by_short_name(ae_run):
+    model = ae_run.get_model(name="rf")
+    assert get_model_name(model) == "RandomForest"
+
+
+def test_get_model_by_invalid_name(ae_run):
+    with pytest.raises(ValueError):
+        ae_run.get_model(name="invalid_name")
+
+
+def test_get_model_by_rank(ae_run):
     model = ae_run.get_model(rank=1)
     assert model is not None


 def test_get_model_with_invalid_rank(ae_run):
-    # Test getting a model with an invalid rank
     with pytest.raises(RuntimeError):
         ae_run.get_model(rank=0)
```
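The tests pin down the error behavior of the new branches: name lookup raises `ValueError` for a non-string or unknown name, while rank lookup keeps raising `RuntimeError` for an out-of-range rank or when `compare()` has not been run. A short defensive-use sketch, reusing the assumed `ae` instance from the earlier example:

```python
# Error-handling sketch mirroring the raise statements in the diff above;
# `ae` is the assumed AutoEmulate instance from the earlier example.
try:
    ae.get_model(name="NotARealModel")
except ValueError as err:
    print(err)  # "Model NotARealModel not found"

try:
    ae.get_model(rank=0)
except RuntimeError as err:
    print(err)  # rank must be >= 1 and <= the number of compared models
```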
