From a15442c3761c33d2eee981868f9c8ac59eb1685b Mon Sep 17 00:00:00 2001
From: mastoffel
Date: Tue, 13 Aug 2024 11:42:24 +0100
Subject: [PATCH] fix get_model

---
 autoemulate/compare.py | 35 +++++++++++++++++++----------------
 tests/test_compare.py  | 20 +++++++++++++++++---
 2 files changed, 36 insertions(+), 19 deletions(-)

diff --git a/autoemulate/compare.py b/autoemulate/compare.py
index 72bdb4bb..4c9883cb 100644
--- a/autoemulate/compare.py
+++ b/autoemulate/compare.py
@@ -254,14 +254,12 @@ def compare(self):
         )
 
         # get best model
-        best_model_name, self.best_model = self.get_model(
-            rank=1, metric="r2", name=True
-        )
+        self.best_model = self.get_model(rank=1, metric="r2")
 
         return self.best_model
 
-    def get_model(self, rank=1, metric="r2", name=False):
-        """Get a fitted model based on it's rank in the comparison.
+    def get_model(self, rank=1, metric="r2", name=None):
+        """Get a fitted model based on its rank in the comparison or its name.
 
         Parameters
         ----------
@@ -269,8 +267,8 @@ def get_model(self, rank=1, metric="r2", name=False):
             Rank of the model to return. Defaults to 1, which is the best model, 2 is the second best, etc.
         metric : str
            Metric to use for determining the best model.
-        name : bool
-            If True, returns tuple of model name and model. If False, returns only the model.
+        name : str
+            Name of the model to return.
 
         Returns
         -------
@@ -278,27 +276,32 @@ def get_model(self, rank=1, metric="r2", name=False):
             Model fitted on full data.
         """
 
-        if not hasattr(self, "scores_df"):
+        # get model by name
+        if name is not None:
+            if not isinstance(name, str):
+                raise ValueError("Name must be a string")
+            for model in self.models:
+                if get_model_name(model) == name or get_short_model_name(model) == name:
+                    return model
+            raise ValueError(f"Model {name} not found")
+
+        # check that comparison has been run
+        if not hasattr(self, "scores_df") and name is None:
             raise RuntimeError("Must run compare() before get_model()")
 
-        # get average scores across folds
-        means = get_mean_scores(self.scores_df, metric)
         # get model by rank
+        means = get_mean_scores(self.scores_df, metric)
+
         if (rank > len(means)) or (rank < 1):
             raise RuntimeError(f"Rank must be >= 1 and <= {len(means)}")
         chosen_model_name = means.iloc[rank - 1]["model"]
 
-        # get best model:
         for model in self.models:
             if get_model_name(model) == chosen_model_name:
                 chosen_model = model
                 break
 
-        # check whether the model is fitted
-        check_is_fitted(chosen_model)
-
-        if name:
-            return chosen_model_name, chosen_model
+        # check_is_fitted(chosen_model)
         return chosen_model
 
     def refit_model(self, model):
diff --git a/tests/test_compare.py b/tests/test_compare.py
index 70590ccb..256b2dbe 100644
--- a/tests/test_compare.py
+++ b/tests/test_compare.py
@@ -12,6 +12,7 @@
 from autoemulate.experimental_design import ExperimentalDesign
 from autoemulate.experimental_design import LatinHypercube
 from autoemulate.metrics import METRIC_REGISTRY
+from autoemulate.utils import get_model_name
 
 
 @pytest.fixture()
@@ -114,14 +115,27 @@ def test__get_metrics(ae):
 
 
 # -----------------------test get_model-------------------#
-def test_get_model(ae_run):
-    # Test getting the best model
+def test_get_model_by_name(ae_run):
+    model = ae_run.get_model(name="RandomForest")
+    assert get_model_name(model) == "RandomForest"
+
+
+def test_get_model_by_short_name(ae_run):
+    model = ae_run.get_model(name="rf")
+    assert get_model_name(model) == "RandomForest"
+
+
+def test_get_model_by_invalid_name(ae_run):
+    with pytest.raises(ValueError):
+        ae_run.get_model(name="invalid_name")
+
+
+def test_get_model_by_rank(ae_run):
     model = ae_run.get_model(rank=1)
     assert model is not None
 
 
 def test_get_model_with_invalid_rank(ae_run):
-    # Test getting a model with an invalid rank
     with pytest.raises(RuntimeError):
         ae_run.get_model(rank=0)
 
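After this change, get_model() can be called either by rank or by full or short model name. A minimal usage sketch, assuming the package's usual AutoEmulate().setup(X, y) / compare() workflow from autoemulate.compare; the X/y arrays below are placeholder data and the "RandomForest"/"rf" names come from the tests above:

    import numpy as np
    from autoemulate.compare import AutoEmulate

    # placeholder data: 50 samples, 2 input parameters, 1 output
    X = np.random.rand(50, 2)
    y = X[:, 0] + X[:, 1]

    ae = AutoEmulate()
    ae.setup(X, y)
    ae.compare()                                 # required before rank-based lookup

    best = ae.get_model(rank=1, metric="r2")     # best model by mean cross-validated r2
    rf = ae.get_model(name="RandomForest")       # lookup by full model name
    rf = ae.get_model(name="rf")                 # or by short name
    # ae.get_model(name="not_a_model")           # would raise ValueError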