diff --git a/autoemulate/compare.py b/autoemulate/compare.py
index 41cdd676..aeaf969d 100644
--- a/autoemulate/compare.py
+++ b/autoemulate/compare.py
@@ -74,11 +74,6 @@ def setup(
         self.n_jobs = n_jobs
         self.logger = configure_logging(log_to_file=log_to_file)
         self.is_set_up = True
-        self.scores_df = pd.DataFrame(
-            columns=["model", "metric", "fold", "score"]
-        ).astype(
-            {"model": "object", "metric": "object", "fold": "int64", "score": "float64"}
-        )
         self.cv_results = {}
 
     def compare(self):
@@ -92,6 +87,13 @@ def compare(self):
         if not self.is_set_up:
             raise RuntimeError("Must run setup() before compare()")
 
+        # Freshly initialise scores dataframe when running compare()
+        self.scores_df = pd.DataFrame(
+            columns=["model", "metric", "fold", "score"]
+        ).astype(
+            {"model": "object", "metric": "object", "fold": "int64", "score": "float64"}
+        )
+
         for model in self.models:
             # search for best hyperparameters
             if self.hyperparameter_search:
@@ -136,7 +138,7 @@ def cross_validate(self, model):
         self.logger.info(f"Parameters: {model.get_params()}")
 
         # Cross-validate
-        cv = cross_validate(
+        cv_results = cross_validate(
             model,
             self.X,
             self.y,
@@ -146,12 +148,29 @@ def cross_validate(self, model):
             return_estimator=True,
             return_indices=True,
         )
+        self._update_scores_df(model_name, cv_results)
 
+    def _update_scores_df(self, model_name, cv_results):
+        """Updates the scores dataframe with the results of the cross-validation.
+
+        Parameters
+        ----------
+        model_name : str
+            Name of the model.
+        cv_results : dict
+            Results of the cross-validation.
+
+        Returns
+        -------
+        None
+            Updates ``self.scores_df`` in place, one row per model, metric and fold.
+
+        """
         # Gather scores from each metric
         # Initialise scores dataframe
-        for key in cv.keys():
+        for key in cv_results.keys():
             if key.startswith("test_"):
-                for fold, score in enumerate(cv[key]):
+                for fold, score in enumerate(cv_results[key]):
                     self.scores_df.loc[len(self.scores_df.index)] = {
                         "model": model_name,
                         "metric": key.split("test_", 1)[1],
@@ -159,7 +178,7 @@ def cross_validate(self, model):
                         "score": score,
                     }
         # save results for plotting etc.
-        self.cv_results[model_name] = cv
+        self.cv_results[model_name] = cv_results
 
     def hyperparam_search(self, model):
         """Performs hyperparameter search for a given model.
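
For reviewers, a minimal standalone sketch of the long-format conversion that the extracted `_update_scores_df` performs, driving scikit-learn's `cross_validate` directly; the toy dataset, `LinearRegression` estimator, and scoring names are illustrative only and not part of the patch:

```python
import pandas as pd
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_validate

# Illustrative data and model; any estimator/scorers would do here
X, y = make_regression(n_samples=100, n_features=5, random_state=0)
cv_results = cross_validate(
    LinearRegression(), X, y, cv=5, scoring=["r2", "neg_mean_squared_error"]
)

# Fresh long-format dataframe, mirroring the dtypes set up in compare()
scores_df = pd.DataFrame(
    columns=["model", "metric", "fold", "score"]
).astype(
    {"model": "object", "metric": "object", "fold": "int64", "score": "float64"}
)

# One row per (model, metric, fold), as in _update_scores_df:
# "test_r2" -> metric "r2", one entry per cv fold
for key in cv_results.keys():
    if key.startswith("test_"):
        for fold, score in enumerate(cv_results[key]):
            scores_df.loc[len(scores_df.index)] = {
                "model": "LinearRegression",
                "metric": key.split("test_", 1)[1],
                "fold": fold,
                "score": score,
            }

print(scores_df)
```

Because the dataframe is now re-initialised at the top of `compare()` rather than once in `setup()`, a second call to `compare()` starts from an empty `scores_df` instead of appending a duplicate set of fold scores.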