Skip to content

Commit

Permalink
change printing text
Browse files Browse the repository at this point in the history
  • Loading branch information
mastoffel committed Feb 9, 2024
1 parent 8c73294 commit 5ec602f
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 11 deletions.
11 changes: 9 additions & 2 deletions autoemulate/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def setup(
param_search=False,
param_search_type="random",
param_search_iters=20,
param_search_test_size=0.2,
scale=True,
scaler=StandardScaler(),
reduce_dim=False,
Expand Down Expand Up @@ -94,7 +95,7 @@ def setup(
"""
self.X, self.y = self._check_input(X, y)
self.train_idxs, self.test_idxs = split_data(
self.X, test_size=0.2, param_search=param_search
self.X, test_size=param_search_test_size, param_search=param_search
)
self.models = get_and_process_models(
MODEL_REGISTRY,
Expand Down Expand Up @@ -306,7 +307,13 @@ def print_results(self, sort_by="r2", model=None):
The name of the model to print. If None, the best fold from each model will be printed.
If a model name is provided, the scores for that model across all folds will be printed.
"""
print_cv_results(self.models, self.scores_df, model=model, sort_by=sort_by)
print_cv_results(
self.models,
self.scores_df,
model=model,
sort_by=sort_by,
param_search=self.param_search,
)

def plot_results(
self,
Expand Down
32 changes: 23 additions & 9 deletions autoemulate/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from autoemulate.utils import get_model_name


def print_cv_results(models, scores_df, model=None, sort_by="r2"):
def print_cv_results(models, scores_df, model=None, sort_by="r2", param_search=False):
"""Print cv results.
Parameters
Expand All @@ -26,12 +26,26 @@ def print_cv_results(models, scores_df, model=None, sort_by="r2"):
f"Model {model} not found. Available models are: {model_names}"
)
if model is None:
means = get_mean_scores(scores_df, metric=sort_by)
print("Average scores across all models:")
print(means)
if param_search:
means = get_mean_scores(scores_df, metric=sort_by)
print("Test score for each model:")
print(means)
else:
means = get_mean_scores(scores_df, metric=sort_by)
print("Average scores across all models:")
print(means)
else:
scores = scores_df[scores_df["model"] == model].pivot(
index="fold", columns="metric", values="score"
)
print(f"Scores for {model} across all folds:")
print(scores)
if param_search:
scores = scores_df[scores_df["model"] == model].pivot(
index="fold", columns="metric", values="score"
)
# drop metric column
# get index of scores
print(f"Test score for {model}:")
print(scores)
else:
scores = scores_df[scores_df["model"] == model].pivot(
index="fold", columns="metric", values="score"
)
print(f"Scores for {model} across all folds:")
print(scores)

0 comments on commit 5ec602f

Please sign in to comment.