From 0f3e0489f4c1fb363ca1c85bb7f2564455bad5dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Sat, 26 Oct 2024 13:42:41 -0400 Subject: [PATCH] ENH: Compute the number of CV repetitions dynamically Compute the number of CV repetitions dynamically by dividing the maximum number of folds requested by the user by the number of folds of each run, replacing the hard-coded total of 120. --- scripts/dwi_gp_estimation_error_analysis.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/scripts/dwi_gp_estimation_error_analysis.py b/scripts/dwi_gp_estimation_error_analysis.py index cca074c4..289a29fb 100644 --- a/scripts/dwi_gp_estimation_error_analysis.py +++ b/scripts/dwi_gp_estimation_error_analysis.py @@ -47,6 +47,7 @@ def cross_validate( X: np.ndarray, y: np.ndarray, cv: int, + n_repeats: int, gpr: EddyMotionGPR, ) -> dict[int, list[tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]]: """ @@ -60,6 +61,8 @@ def cross_validate( DWI signal. cv : :obj:`int` number of folds + n_repeats : :obj:`int` + Number of times the cross-validator needs to be repeated. gpr : obj:`~eddymotion.model._sklearn.EddyMotionGPR` The eddymotion Gaussian process regressor object. @@ -70,7 +73,7 @@ def cross_validate( """ - rkf = RepeatedKFold(n_splits=cv, n_repeats=120 // cv) + rkf = RepeatedKFold(n_splits=cv, n_repeats=n_repeats) scores = cross_val_score(gpr, X, y, scoring="neg_root_mean_squared_error", cv=rkf) return scores @@ -161,7 +164,7 @@ def main() -> None: scores = defaultdict(list, {}) for n in args.kfold: for i in range(args.repeats): - cv_scores = -1.0 * cross_validate(X, y.T, n, gpr) + cv_scores = -1.0 * cross_validate(X, y.T, n, np.max(args.kfold) // n, gpr) scores["rmse"] += cv_scores.tolist() scores["repeat"] += [i] * len(cv_scores) scores["n_folds"] += [n] * len(cv_scores)