Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve user interface #195

Merged
merged 5 commits into from
Mar 1, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 85 additions & 2 deletions autoemulate/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,13 @@ def setup(
self.n_jobs = n_jobs
self.logger = _configure_logging(log_to_file=log_to_file)
self.is_set_up = True
self.dim_reducer = dim_reducer
self.reduce_dim = reduce_dim
self.folds = folds
self.cv_results = {}

self.print_settings()

def _check_input(self, X, y):
"""Checks and possibly converts the input data.

Expand Down Expand Up @@ -198,7 +203,12 @@ def compare(self):
self.scores_df = pd.DataFrame(
columns=["model", "metric", "fold", "score"]
).astype(
{"model": "object", "metric": "object", "fold": "int64", "score": "float64"}
{
"model": "object",
"metric": "object",
"fold": "int64",
"score": "float64",
}
)

for i in range(len(self.models)):
Expand Down Expand Up @@ -368,6 +378,74 @@ def load_model(self, path=None):

return serialiser._load_model(path)

def print_settings(self):
    """Print the settings this instance was set up with.

    Writes to stdout a two-column table (label, value) describing the data
    shapes and the options stored on the instance, followed by the class
    names of the configured models and the names of the metrics.

    Raises
    ------
    RuntimeError
        If ``self.is_set_up`` is falsy, i.e. ``setup()`` has not been run.
    """
    if not self.is_set_up:
        raise RuntimeError("Must run setup() before print_settings()")

    # Every entry of self.models is iterated through its ``steps`` as
    # (name, estimator) pairs; only the step named "model" is reported.
    models = "\n- " + "\n- ".join(
        [
            step[1].__class__.__name__
            for pipeline in self.models
            for step in pipeline.steps
            if step[0] == "model"
        ]
    )
    metrics = "\n- " + "\n- ".join([metric.__name__ for metric in self.metrics])

    # Keep each label next to its value so the two cannot drift apart.
    rows = {
        "Simulation input shape (X)": str(self.X.shape),
        "Simulation output shape (y)": str(self.y.shape),
        "Training dataset shape (train_idxs)": str(self.train_idxs.shape),
        "Test dataset shape (test_idxs)": str(self.test_idxs.shape),
        "Do hyperparameter search (param_search)": str(self.param_search),
        "Type of hyperparameter search (search_type)": str(self.search_type),
        "# sampled parameter settings (param_search_iters)": str(
            self.param_search_iters
        ),
        "Scale data before fitting (scale)": str(self.scale),
        "Scaler (scaler)": str(
            self.scaler.__class__.__name__ if self.scaler is not None else None
        ),
        "Dimensionality reduction before fitting (reduce_dim)": str(
            self.reduce_dim
        ),
        "Dimensionality reduction method (dim_reducer)": str(
            self.dim_reducer.__class__.__name__
            if self.dim_reducer is not None
            else None
        ),
        "Cross-validation strategy (cv)": str(
            self.cv.__class__.__name__ if self.cv is not None else "None"
        ),
        "# folds (folds)": str(self.folds),
        "# parallel jobs (n_jobs)": str(self.n_jobs),
    }
    settings = pd.DataFrame(
        list(rows.values()),
        index=list(rows.keys()),
        columns=["Values"],
    )

    settings_str = settings.to_string(index=True, header=False)
    # Size the separator to the widest rendered table line.
    width = max(len(line) for line in settings_str.split("\n"))

    print("AutoEmulate is set up with the following settings:")
    print("-" * width)
    print(settings_str)
    print("-" * width)
    print("Models:" + models)
    print("-" * width)
    print("Metrics:" + metrics)

def print_results(self, model=None, sort_by="r2"):
"""Print cv results.

Expand Down Expand Up @@ -470,5 +548,10 @@ def plot_model(self, model, plot="standard", n_cols=2, figsize=None):
Number of columns in the plot grid for multi-output. Default is 2.
"""
_plot_model(
model, self.X[self.test_idxs], self.y[self.test_idxs], plot, n_cols, figsize
model,
self.X[self.test_idxs],
self.y[self.test_idxs],
plot,
n_cols,
figsize,
)
Loading