diff --git a/scripts/dwi_gp_estimation_error_analysis.py b/scripts/dwi_gp_estimation_error_analysis.py index cca074c4..289a29fb 100644 --- a/scripts/dwi_gp_estimation_error_analysis.py +++ b/scripts/dwi_gp_estimation_error_analysis.py @@ -47,6 +47,7 @@ def cross_validate( X: np.ndarray, y: np.ndarray, cv: int, + n_repeats: int, gpr: EddyMotionGPR, ) -> dict[int, list[tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]]: """ @@ -60,6 +61,8 @@ def cross_validate( DWI signal. cv : :obj:`int` number of folds + n_repeats : :obj:`int` + Number of times the cross-validator needs to be repeated. gpr : obj:`~eddymotion.model._sklearn.EddyMotionGPR` The eddymotion Gaussian process regressor object. @@ -70,7 +73,7 @@ def cross_validate( """ - rkf = RepeatedKFold(n_splits=cv, n_repeats=120 // cv) + rkf = RepeatedKFold(n_splits=cv, n_repeats=n_repeats) scores = cross_val_score(gpr, X, y, scoring="neg_root_mean_squared_error", cv=rkf) return scores @@ -161,7 +164,7 @@ def main() -> None: scores = defaultdict(list, {}) for n in args.kfold: for i in range(args.repeats): - cv_scores = -1.0 * cross_validate(X, y.T, n, gpr) + cv_scores = -1.0 * cross_validate(X, y.T, n, np.max(args.kfold) // n, gpr) scores["rmse"] += cv_scores.tolist() scores["repeat"] += [i] * len(cv_scores) scores["n_folds"] += [n] * len(cv_scores)