make sure scores_df is always freshly created
mastoffel committed Nov 10, 2023
1 parent e7cd2d9 commit 00c28f7
Showing 1 changed file with 28 additions and 9 deletions.
autoemulate/compare.py: 28 additions & 9 deletions
@@ -74,11 +74,6 @@ def setup(
         self.n_jobs = n_jobs
         self.logger = configure_logging(log_to_file=log_to_file)
         self.is_set_up = True
-        self.scores_df = pd.DataFrame(
-            columns=["model", "metric", "fold", "score"]
-        ).astype(
-            {"model": "object", "metric": "object", "fold": "int64", "score": "float64"}
-        )
         self.cv_results = {}

     def compare(self):
@@ -92,6 +87,13 @@ def compare(self):
         if not self.is_set_up:
             raise RuntimeError("Must run setup() before compare()")

+        # Freshly initialise scores dataframe when running compare()
+        self.scores_df = pd.DataFrame(
+            columns=["model", "metric", "fold", "score"]
+        ).astype(
+            {"model": "object", "metric": "object", "fold": "int64", "score": "float64"}
+        )
+
         for model in self.models:
             # search for best hyperparameters
             if self.hyperparameter_search:
@@ -136,7 +138,7 @@ def cross_validate(self, model):
         self.logger.info(f"Parameters: {model.get_params()}")

         # Cross-validate
-        cv = cross_validate(
+        cv_results = cross_validate(
             model,
             self.X,
             self.y,
@@ -146,20 +148,37 @@
             return_estimator=True,
             return_indices=True,
         )
+        self._update_scores_df(model_name, cv_results)
+
+    def _update_scores_df(self, model_name, cv_results):
+        """Updates the scores dataframe with the results of the cross-validation.
+
+        Parameters
+        ----------
+        model_name : str
+            Name of the model.
+        cv_results : dict
+            Results of the cross-validation.
+
+        Returns
+        -------
+        scores_df : pandas.DataFrame
+            Dataframe containing the scores for each model, metric and fold.
+        """
         # Gather scores from each metric
         # Initialise scores dataframe
-        for key in cv.keys():
+        for key in cv_results.keys():
             if key.startswith("test_"):
-                for fold, score in enumerate(cv[key]):
+                for fold, score in enumerate(cv_results[key]):
                     self.scores_df.loc[len(self.scores_df.index)] = {
                         "model": model_name,
                         "metric": key.split("test_", 1)[1],
                         "fold": fold,
                         "score": score,
                     }
         # save results for plotting etc.
-        self.cv_results[model_name] = cv
+        self.cv_results[model_name] = cv_results

     def hyperparam_search(self, model):
         """Performs hyperparameter search for a given model.
