Commit
refactor cv and update scores df into separate functions
Refactors cross-validation and the scores-DataFrame update into separate functions (run_cv and update_scores_df), switches the argument order, and adds tests for cross-validation.
Showing 4 changed files with 142 additions and 88 deletions.
@@ -0,0 +1,67 @@
import pandas as pd
from sklearn.metrics import make_scorer
from sklearn.model_selection import cross_validate

from autoemulate.utils import get_model_name


def run_cv(X, y, cv, model, metrics, n_jobs, logger):
    """Cross-validates a single model and returns the raw scikit-learn results.

    Parameters
    ----------
    X : array-like
        Input features.
    y : array-like
        Target values.
    cv : int or cross-validation splitter
        Cross-validation strategy passed to cross_validate.
    model : sklearn.pipeline.Pipeline
        Pipeline whose final step is named "model".
    metrics : list of callables
        Metric functions, turned into scorers via make_scorer.
    n_jobs : int
        Number of parallel jobs for cross-validation.
    logger : logging.Logger
        Logger used for progress and error messages.

    Returns
    -------
    cv_results : dict
        Output of sklearn.model_selection.cross_validate.
    """
    # Get model name
    model_name = get_model_name(model)

    # The metrics we want to use for cross-validation
    scorers = {metric.__name__: make_scorer(metric) for metric in metrics}

    logger.info(f"Cross-validating {model_name}...")
    logger.info(f"Parameters: {model.named_steps['model'].get_params()}")

    try:
        # Cross-validate
        cv_results = cross_validate(
            model,
            X,
            y,
            cv=cv,
            scoring=scorers,
            n_jobs=n_jobs,
            return_estimator=True,
            return_indices=True,
        )
    except Exception as e:
        logger.error(f"Failed to cross-validate {model_name}")
        logger.error(e)
        # Re-raise so the caller is not handed an undefined cv_results
        raise

    return cv_results


def update_scores_df(scores_df, model, cv_results):
    """Updates the scores dataframe with the results of the cross-validation.

    Parameters
    ----------
    scores_df : pandas.DataFrame
        DataFrame with columns "model", "metric", "fold", "score".
    model : sklearn.pipeline.Pipeline
        Model (pipeline) that was cross-validated.
    cv_results : dict
        Results of the cross-validation.

    Returns
    -------
    scores_df : pandas.DataFrame
        The same DataFrame, updated in place with one row per metric and fold.
    """
    # Gather scores from each metric and fold and append them as new rows
    for key in cv_results.keys():
        if key.startswith("test_"):
            for fold, score in enumerate(cv_results[key]):
                scores_df.loc[len(scores_df.index)] = {
                    "model": get_model_name(model),
                    "metric": key.split("test_", 1)[1],
                    "fold": fold,
                    "score": score,
                }
    return scores_df
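
For context, here is a minimal usage sketch of the two new functions (which the tests below import from autoemulate.cross_validate). The metric functions, model, and logging setup in this sketch are illustrative assumptions, not part of this diff:

import logging

import numpy as np
import pandas as pd
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import KFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

from autoemulate.cross_validate import run_cv, update_scores_df


# Illustrative metric functions; in autoemulate these come from METRIC_REGISTRY.
def r2(y_true, y_pred):
    return r2_score(y_true, y_pred)


def rmse(y_true, y_pred):
    return np.sqrt(mean_squared_error(y_true, y_pred))


X, y = make_regression(n_samples=50, n_features=3, random_state=0)
# The pipeline's final step must be named "model", since run_cv logs
# model.named_steps["model"].get_params().
model = Pipeline(
    [("scaler", StandardScaler()), ("model", RandomForestRegressor(random_state=0))]
)
logger = logging.getLogger(__name__)

cv_results = run_cv(X, y, KFold(n_splits=5), model, [r2, rmse], n_jobs=1, logger=logger)

# Collect per-fold scores into a long-format DataFrame.
scores_df = pd.DataFrame(columns=["model", "metric", "fold", "score"])
scores_df = update_scores_df(scores_df, model, cv_results)
print(scores_df.head())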
@@ -0,0 +1,55 @@
import logging
from typing import List

import numpy as np
import pandas as pd
import pytest
from sklearn.datasets import make_regression
from sklearn.model_selection import KFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

from autoemulate.compare import AutoEmulate
from autoemulate.cross_validate import run_cv
from autoemulate.cross_validate import update_scores_df
from autoemulate.emulators import RandomForest
from autoemulate.metrics import METRIC_REGISTRY

# Shared test setup: a small regression problem, a 5-fold CV splitter,
# a scaler + RandomForest pipeline, and an empty scores DataFrame.
X, y = make_regression(n_samples=20, n_features=2, random_state=0)
cv = KFold(n_splits=5, shuffle=True)
model = Pipeline([("scaler", StandardScaler()), ("model", RandomForest())])
metrics = [metric for metric in METRIC_REGISTRY.values()]
n_jobs = 1
logger = logging.getLogger(__name__)
scores_df = pd.DataFrame(columns=["model", "metric", "fold", "score"]).astype(
    {"model": "object", "metric": "object", "fold": "int64", "score": "float64"}
)


@pytest.fixture()
def cv_results():
    return run_cv(X, y, cv, model, metrics, n_jobs, logger)


def test_cv(cv_results):
    assert isinstance(cv_results, dict)
    # check that it contains test scores for each metric
    assert "test_r2" in cv_results.keys()
    assert "test_rsme" in cv_results.keys()

    assert isinstance(cv_results["test_r2"], np.ndarray)
    assert isinstance(cv_results["test_rsme"], np.ndarray)

    assert len(cv_results["test_r2"]) == 5
    assert len(cv_results["test_rsme"]) == 5


def test_update_scores_df(cv_results):
    scores_df_new = update_scores_df(scores_df, model, cv_results)
    assert isinstance(scores_df_new, pd.DataFrame)

    assert scores_df_new.shape[0] == 10
    assert scores_df_new.shape[1] == 4
    assert scores_df_new["model"][0] == "RandomForest"