Skip to content

Commit

Permalink
Merge pull request #246 from alan-turing-institute/fix-model-subset
Browse files Browse the repository at this point in the history
load core models by default
  • Loading branch information
mastoffel authored Sep 23, 2024
2 parents c687681 + 6600a43 commit e823a63
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 2 deletions.
4 changes: 3 additions & 1 deletion autoemulate/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ def setup(
self.train_idxs, self.test_idxs = _split_data(
self.X, test_size=test_set_size, random_state=42
)
self.model_names = self.model_registry.get_model_names(model_subset)
self.model_names = self.model_registry.get_model_names(
model_subset, is_core=True
)
self.models = _process_models(
model_registry=self.model_registry,
model_subset=list(self.model_names.keys()),
Expand Down
8 changes: 7 additions & 1 deletion autoemulate/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def register_model(self, model_name, model_class, is_core=False):
if is_core:
self.core_model_names.append(model_name)

def get_model_names(self, model_subset=None):
def get_model_names(self, model_subset=None, is_core=False):
"""Get a dictionary of (all) model names and their short names
Parameters
Expand Down Expand Up @@ -60,6 +60,12 @@ def get_model_names(self, model_subset=None):
for k, v in model_names.items()
if k in model_subset or v in model_subset
}

if is_core:
model_names = {
k: v for k, v in model_names.items() if k in self.core_model_names
}

return model_names

def get_core_models(self):
Expand Down
7 changes: 7 additions & 0 deletions tests/test_model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,13 @@ def test_get_model_names_invalid_str(model_registry):
model_registry.get_model_names(model_subset="invalid")


def test_get_model_names_is_core(model_registry):
model_names = model_registry.get_model_names(is_core=True)
assert isinstance(model_names, dict)
assert len(model_names) == 2
assert model_names["RadialBasisFunctions"] == "rbf"


# check get_models -------------------------------------------
def test_get_models(model_registry):
models = model_registry.get_models()
Expand Down

0 comments on commit e823a63

Please sign in to comment.