diff --git a/autoemulate/compare.py b/autoemulate/compare.py index 7214cc60..b661c24e 100644 --- a/autoemulate/compare.py +++ b/autoemulate/compare.py @@ -116,7 +116,9 @@ def setup( self.train_idxs, self.test_idxs = _split_data( self.X, test_size=test_set_size, random_state=42 ) - self.model_names = self.model_registry.get_model_names(model_subset) + self.model_names = self.model_registry.get_model_names( + model_subset, is_core=True + ) self.models = _process_models( model_registry=self.model_registry, model_subset=list(self.model_names.keys()), diff --git a/autoemulate/model_registry.py b/autoemulate/model_registry.py index 36e9d97e..c50d14c1 100644 --- a/autoemulate/model_registry.py +++ b/autoemulate/model_registry.py @@ -15,7 +15,7 @@ def register_model(self, model_name, model_class, is_core=False): if is_core: self.core_model_names.append(model_name) - def get_model_names(self, model_subset=None): + def get_model_names(self, model_subset=None, is_core=False): """Get a dictionary of (all) model names and their short names Parameters @@ -60,6 +60,12 @@ def get_model_names(self, model_subset=None): for k, v in model_names.items() if k in model_subset or v in model_subset } + + if is_core: + model_names = { + k: v for k, v in model_names.items() if k in self.core_model_names + } + return model_names def get_core_models(self): diff --git a/tests/test_model_registry.py b/tests/test_model_registry.py index 6796a464..e2336a77 100644 --- a/tests/test_model_registry.py +++ b/tests/test_model_registry.py @@ -90,6 +90,13 @@ def test_get_model_names_invalid_str(model_registry): model_registry.get_model_names(model_subset="invalid") +def test_get_model_names_is_core(model_registry): + model_names = model_registry.get_model_names(is_core=True) + assert isinstance(model_names, dict) + assert len(model_names) == 2 + assert model_names["RadialBasisFunctions"] == "rbf" + + # check get_models ------------------------------------------- def test_get_models(model_registry): models = model_registry.get_models()